# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..wrapped_decorator import signature_safe_contextmanager

from .layer_function_generator import templatedoc
from .. import core
from ..framework import (
    Program,
    Variable,
    Operator,
    static_only,
    in_dygraph_mode,
)
from ..layer_helper import LayerHelper, unique_name
from ...utils import (
    assert_same_structure,
    map_structure,
    hold_mutable_vars,
    copy_mutable_vars,
    is_sequence,
    pack_sequence_as,
    flatten,
    to_sequence,
)
import numpy
import warnings
from functools import reduce, partial
from ..data_feeder import (
    convert_dtype,
    check_variable_and_dtype,
    check_type,
    check_dtype,
)
from ..backward import _infer_var_data_type_shape_
import paddle
from paddle import _C_ops, _legacy_C_ops

__all__ = [
    'Switch',
    'StaticRNN',
    'while_loop',
]


def select_output(input, outputs, mask):
    """
    **select_output**
    This API takes one input, multiple outputs and an integer mask. It
    selects the output specified by the mask and copies the input to the
    selected output. It is useful in control flow.

    Args:
        input(Variable): The input variable
        outputs(tuple|list): The output variables
        mask(Variable): A tensor containing a single integer that selects
            which output the input is copied to

    Returns:
        Variable: The output variables
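
    Examples:
        A minimal usage sketch (illustrative only; it assumes static graph mode,
        that this module is reachable as ``fluid.layers.control_flow``, and that
        the pre-created outputs share the dtype of ``input``):

        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            paddle.enable_static()
            x = paddle.full(shape=[1], dtype='float32', fill_value=3.0)
            out_a = paddle.full(shape=[1], dtype='float32', fill_value=0.0)
            out_b = paddle.full(shape=[1], dtype='float32', fill_value=0.0)
            mask = paddle.full(shape=[1], dtype='int32', fill_value=1)
            # copy x into the output selected by mask (here out_b)
            outs = fluid.layers.control_flow.select_output(x, [out_a, out_b], mask)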
    """
    helper = LayerHelper('select_output', **locals())
    check_type(input, 'input', (Variable), 'select_output')
    check_variable_and_dtype(mask, 'mask', ['int32'], 'select_output')
    check_type(outputs, 'outputs', (list, tuple), 'select_output')

    helper.append_op(
        type='select_output',
        inputs={'X': input, 'Mask': mask},
        outputs={'Out': outputs},
    )
    return outputs


def _select_input_infer_shape(first_shape, second_shape):
    """
    This function infers the output shape with the following algorithm:
    1. if the ranks differ, warn and fall back to second_shape.
    2. compare the axes one by one:
        if a == b: we set the axis to a
        if a != b: we set the axis to -1
    For compatibility, in non-declarative mode, we just return second_shape.
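
    A small illustration of the rule above (not a doctest; the shapes are
    arbitrary examples):
        _select_input_infer_shape([3, 4, 5], [3, 7, 5])  # -> [3, -1, 5]
        _select_input_infer_shape([3, 4], [3, 4, 5])     # -> warns, returns [3, 4, 5]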
    """
    if len(first_shape) != len(second_shape):
        warnings.warn(
            f"the input shapes of select_input should have the same rank, but get {first_shape}, {second_shape}"
        )
        return second_shape
    out_shape = list(
        map(lambda a, b: a if a == b else -1, first_shape, second_shape)
    )
    return out_shape


def select_input(inputs, mask):
    """
    **select_input**

    This API takes in multiple inputs and uses an integer mask to select one
    input to output. It is useful in control flow.

    Args:
        inputs(tuple|list): The input variables
        mask(Variable): A tensor containing a single integer that selects
            which input to output

    Returns:
        Variable: The selected input variable
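
    Examples:
        A minimal usage sketch (illustrative only; it assumes static graph mode,
        that this module is reachable as ``fluid.layers.control_flow``, and that
        both inputs share the same dtype):

        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            paddle.enable_static()
            a = paddle.full(shape=[1], dtype='float32', fill_value=1.0)
            b = paddle.full(shape=[1], dtype='float32', fill_value=2.0)
            mask = paddle.full(shape=[1], dtype='int32', fill_value=1)
            # out holds the input selected by mask (here b)
            out = fluid.layers.control_flow.select_input([a, b], mask)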
    """
    helper = LayerHelper('select_input', **locals())
    check_type(inputs, 'inputs', (list, tuple), 'select_input')
    check_variable_and_dtype(mask, 'mask', ['int32'], 'select_input')

    # select_input should expand the shape: if one dim is -1 and the other is a valid number, use -1 first. If the ranks differ, a warning is reported directly.
    # assert inputs[0].dtype == inputs[1].dtype, f"Expect the inputs should have the same dtype, but get {inputs[0].dtype} and {inputs[1].dtype}"

    output_shape = _select_input_infer_shape(inputs[0].shape, inputs[1].shape)
    output_dtype = inputs[1].dtype
    output_type = inputs[1].type

    out = helper.create_variable(
        dtype=output_dtype, shape=output_shape, type=output_type
    )
    helper.append_op(
        type='select_input',
        inputs={'X': inputs, 'Mask': mask},
        outputs={'Out': out},
    )
    return out


# (TODO: Mine) There exists dependency. It will be removed later.
class BlockGuard:
    """
    BlockGuard class.

    BlockGuard class is used to create a sub-block in a program by
    using the Python `with` keyword.
    """

    def __init__(self, main_program):
        if not isinstance(main_program, Program):
            raise TypeError("BlockGuard takes a program")
        self.main_program = main_program

    def __enter__(self):
        self.main_program._create_block()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.main_program._rollback()
        if exc_type is not None:
            return False  # re-raise exception
        return True


# (TODO: Mine) There exists dependency. It will be removed later.
class BlockGuardWithCompletion(BlockGuard):
    """
    BlockGuardWithCompletion class.

    BlockGuardWithCompletion class is used to create an op with a block in a program.
    """

    def __init__(self, rnn):
        if not isinstance(rnn, StaticRNN):
            raise TypeError("BlockGuardWithCompletion takes a StaticRNN")
        super().__init__(rnn.helper.main_program)
        self.rnn = rnn

    def __enter__(self):
        self.rnn.status = StaticRNN.IN_RNN_BLOCK
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            return False
        self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
        self.rnn._complete_op()
        return super().__exit__(exc_type, exc_val, exc_tb)


class StaticRNNMemoryLink:
    """
    StaticRNNMemoryLink class.

    StaticRNNMemoryLink class is used to create a link between two
    memory cells of a StaticRNN.


    NOTE: This is an internal data structure of a very low-level API.
    Please use StaticRNN instead.

    Args:
        init(Variable): the initial variable for Memory.
        pre_mem(Variable): the memory variable in previous time step.
        mem(Variable): the memory variable in current time step.
    """

    def __init__(self, init, pre_mem, mem=None):
        self.init = init
        self.pre_mem = pre_mem
        self.mem = mem


class StaticRNN:
    """
    :api_attr: Static Graph

    StaticRNN class.

    The StaticRNN can process a batch of sequence data. The first dimension of the input
    represents the sequence length, and the length of each input sequence must be equal.
    StaticRNN will unfold the sequence into time steps, and the user needs to define how to
    process each time step inside the :code:`with` block.

    Args:
        name (str, optional): Please refer to :ref:`api_guide_Name`, Default None.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            import paddle.fluid.layers as layers

            vocab_size, hidden_size=10000, 200
            paddle.enable_static()
            x = paddle.static.data(name="x", shape=[None, 1, 1], dtype='int64')
            # create word sequence
            x_emb = layers.embedding(
                input=x,
                size=[vocab_size, hidden_size],
                dtype='float32',
                is_sparse=False)
            # transform batch size to dim 1
            x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

            rnn = fluid.layers.StaticRNN()
            with rnn.step():
                # mark created x_emb as input, each step process a word
                word = rnn.step_input(x_emb)
                # create prev memory parameter, batch size comes from word
                prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                hidden = paddle.static.nn.fc(x=[word, prev], size=hidden_size, activation='relu')
                # use hidden to update prev
                rnn.update_memory(prev, hidden)
                # mark hidden as output
                rnn.step_output(hidden)
            # get StaticRNN final output
            result = rnn()

    """

    BEFORE_RNN_BLOCK = 0
    IN_RNN_BLOCK = 1
    AFTER_RNN_BLOCK = 2

    def __init__(self, name=None):
        check_type(name, "name", (str, type(None)), "fluid.layers.StaticRNN")
        self.helper = LayerHelper("static_rnn", name=name)
        self.memories = {}  # memory map, from pre_mem.name --> MemoryLink
        self.inputs = []  # input variable list in current block
        self.outputs = []  # output variable list in parent block
        self.status = StaticRNN.BEFORE_RNN_BLOCK  # status flag.
        # sequence length; since it is a static RNN, the sequence length is fixed.
        self.seq_len = None

    def step(self):
        """
        Define operators in each step. step is used in a :code:`with` block, and the OPs in the :code:`with` block
        will be executed sequence_len times (sequence_len is the length of the input).
        """
        return BlockGuardWithCompletion(self)

    def _assert_in_rnn_block_(self, method):
        if self.status != StaticRNN.IN_RNN_BLOCK:
            raise ValueError("You must invoke {0} in rnn block".format(method))

    def memory(
        self,
        init=None,
        shape=None,
        batch_ref=None,
        init_value=0.0,
        init_batch_dim_idx=0,
        ref_batch_dim_idx=1,
    ):
        """
        Create a memory variable for static rnn.
        If the :code:`init` is not None, :code:`memory` will be initialized by
        this Variable. If the :code:`init` is None, :code:`shape` and :code:`batch_ref`
        must be set, and this function will create a new variable with :code:`shape`
        and :code:`batch_ref` to initialize the memory.

        Args:
            init(Variable, optional): Tensor used to init memory. If it is not set,
                :code:`shape` and :code:`batch_ref` must be provided.
                Default: None.
            shape(list|tuple): When :code:`init` is None, use this arg to initialize memory shape.
                NOTE the shape does not contain batch_size. Default: None.
            batch_ref(Variable, optional): When :code:`init` is None, memory's batch size will
                be set as batch_ref's ref_batch_dim_idx value. Default: None.
            init_value(float, optional): When :code:`init` is None, used to init memory's value. Default: 0.0.
            init_batch_dim_idx(int, optional): the batch_size axis of the :code:`init` Variable. Default: 0.
            ref_batch_dim_idx(int, optional): the batch_size axis of the :code:`batch_ref` Variable. Default: 1.

        Returns:
            Variable: The memory variable.

        Examples 1:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = paddle.static.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = paddle.static.nn.fc(x=[word, prev], size=hidden_size, activation='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)


        Examples 2:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers
                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = paddle.static.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])
                boot_memory = paddle.static.data(name='boot', shape=[-1, hidden_size], dtype='float32', lod_level=1)
                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # init memory
                        prev = rnn.memory(init=boot_memory)
                        hidden = paddle.static.nn.fc(x=[word, prev], size=hidden_size, activation='relu')
                        # update hidden with prev
                        rnn.update_memory(prev, hidden)

        """
        self._assert_in_rnn_block_('memory')
        check_type(
            init,
            "init",
            (Variable, type(None)),
            "fluid.layers.StaticRNN.memory",
        )
        check_type(
            shape,
            "shape",
            (list, tuple, type(None)),
            "fluid.layers.StaticRNN.memory",
        )
        check_type(
            batch_ref,
            "batch_ref",
            (Variable, type(None)),
            "fluid.layers.StaticRNN.memory",
        )
        if init is None:
            if shape is None or batch_ref is None:
                raise ValueError(
                    "if init is None, memory needs at least shape and batch_ref"
                )
            parent_block = self._parent_block()
            var_name = unique_name.generate_with_ignorable_key(
                "@".join([self.helper.name, "memory_boot"])
            )
            boot_var = parent_block.create_var(
                name=var_name,
                shape=shape,
                dtype=batch_ref.dtype,
                persistable=False,
            )

            parent_block.append_op(
                type="fill_constant_batch_size_like",
                inputs={'Input': [batch_ref]},
                outputs={'Out': [boot_var]},
                attrs={
                    'value': init_value,
                    'shape': boot_var.shape,
                    'dtype': boot_var.dtype,
                    'input_dim_idx': ref_batch_dim_idx,
                    'output_dim_idx': init_batch_dim_idx,
                },
            )

            return self.memory(init=boot_var)
        else:
            pre_mem = self.helper.create_variable(
                name=unique_name.generate_with_ignorable_key(
                    "@".join([self.helper.name, "mem"])
                ),
                dtype=init.dtype,
                shape=init.shape,
            )
            self.memories[pre_mem.name] = StaticRNNMemoryLink(
                init=init, pre_mem=pre_mem
            )
            return pre_mem

    def step_input(self, x):
        """
        Mark a sequence as a StaticRNN input.

        Args:
            x(Variable): The input sequence, the shape of x
                should be [seq_len, ...].

        Returns:
            Variable: The current time step data in the input sequence.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = paddle.static.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = paddle.static.nn.fc(x=[word, prev], size=hidden_size, activation='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)

        """
        self._assert_in_rnn_block_('step_input')
        check_type(x, "x", Variable, "fluid.layers.StaticRNN.step_input")
        if self.seq_len is None:
            self.seq_len = x.shape[0]
        elif x.shape[0] != -1 and self.seq_len != x.shape[0]:
            raise ValueError("Static RNN only takes fixed seq_len input")

        ipt = self.helper.create_variable(
            name=x.name, dtype=x.dtype, shape=list(x.shape[1:]), type=x.type
        )
        self.inputs.append(ipt)
        return ipt

    def step_output(self, o):
        """
        Mark a sequence as a StaticRNN output.

        Args:
            o(Variable): The output sequence.

        Returns:
            None.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = paddle.static.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = paddle.static.nn.fc(x=[word, prev], size=hidden_size, activation='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)
                        rnn.step_output(hidden)

                result = rnn()

        """
        self._assert_in_rnn_block_('step_output')
        check_type(o, "o", Variable, "fluid.layers.StaticRNN.step_output")

        tmp_o = self.helper.create_variable_for_type_inference(dtype=o.dtype)
        self.helper.append_op(
            type='rnn_memory_helper',
            inputs={'X': [o]},
            outputs={'Out': tmp_o},
            attrs={'dtype': o.dtype},
        )

        out_var = self._parent_block().create_var(
            name=tmp_o.name,
            shape=[self.seq_len] + list(tmp_o.shape),
            dtype=tmp_o.dtype,
        )

        self.outputs.append(out_var)

    def output(self, *outputs):
        """
        Mark the StaticRNN output variables.

        Args:
            outputs: The output Tensors. Multiple variables can be marked as outputs.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = paddle.static.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = paddle.static.nn.fc(x=[word, prev], size=hidden_size, activation='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)
                        # mark each step's hidden and word as output
                        rnn.output(hidden, word)

                result = rnn()
        """
        for each in outputs:
            self.step_output(each)

    def update_memory(self, mem, var):
        """
        Update the memory from :code:`mem` to :code:`var`.

        Args:
            mem(Variable): the memory variable.
            var(Variable): the plain variable generated in RNN block, used to update memory.
                           var and mem should have same dims and data type.

        Returns:
            None

        """
        check_type(mem, "mem", Variable, "fluid.layers.StaticRNN.update_memory")
        check_type(var, "var", Variable, "fluid.layers.StaticRNN.update_memory")
        self.memories[mem.name].mem = var

    def _parent_block(self):
        prog = self.helper.main_program
        parent_idx = prog.current_block().parent_idx
        assert parent_idx >= 0
        parent_block = prog.block(parent_idx)
        return parent_block

    def __call__(self, *args, **kwargs):
        if self.status != StaticRNN.AFTER_RNN_BLOCK:
            raise ValueError("RNN output can only be retrieved after rnn block")
        if len(self.outputs) == 0:
            raise ValueError("RNN has no output")
        elif len(self.outputs) == 1:
            return self.outputs[0]
        else:
            return self.outputs

    def _complete_op(self):
        main_program = self.helper.main_program
        rnn_block = main_program.current_block()
        parent_block = self._parent_block()

        local_inputs = set()

        for op in rnn_block.ops:
            assert isinstance(op, Operator)
            for oname in op.output_names:
                for out_var_name in op.output(oname):
                    local_inputs.add(out_var_name)

        for var in self.inputs:
            local_inputs.add(var.name)
        for m in self.memories:
            local_inputs.add(m)

        # NOTE(zcd): the params have two categories of variables.
        #   - the variables that are the outputs of ops outside StaticRNN.
        #   - the variables that are the parameters of some layers, for example, conv2d.
        params = list()
        for op in rnn_block.ops:
            assert isinstance(op, Operator)
            for iname in op.input_names:
                for in_var_name in op.input(iname):
                    if in_var_name not in local_inputs:
                        params.append(in_var_name)

        parameters = [
            parent_block._find_var_recursive(name) for name in set(params)
        ]

        step_scope = parent_block.create_var(
            type=core.VarDesc.VarType.STEP_SCOPES
        )

        inlinks = [parent_block.var(i.name) for i in self.inputs]
        outlinks = self.outputs

        # NOTE(zcd): the states may be empty in some cases.
        boot_memories = []
        pre_memories = []
        memories = []
        for _, mem in self.memories.items():
            boot_memories.append(mem.init)
            pre_memories.append(mem.pre_mem.name)
            assert (
                mem.mem is not None
            ), "%s should be updated in every step." % (mem.init.name)
            mem_var = rnn_block.var(mem.mem.name)
            assert isinstance(mem_var, Variable)
            new_mem = self.helper.create_variable_for_type_inference(
                dtype=mem_var.dtype
            )
            rnn_block.append_op(
                type='rnn_memory_helper',
                inputs={'X': [mem_var]},
                outputs={'Out': [new_mem]},
                attrs={'dtype': mem_var.dtype},
            )

            memories.append(new_mem.name)

        parent_block.append_op(
            type='recurrent',
            inputs={
                'inputs': inlinks,
                'initial_states': boot_memories,
                'parameters': parameters,
            },
            outputs={'outputs': outlinks, 'step_scopes': [step_scope]},
            attrs={
                'has_states': len(pre_memories) > 0,
                'ex_states': pre_memories,
                'states': memories,
                'sub_block': rnn_block,
            },
        )


# (TODO: Mine) There exists dependency. It will be removed later.
class WhileGuard(BlockGuard):
    def __init__(self, while_op):
        if not isinstance(while_op, While):
            raise TypeError("WhileGuard takes a while op")
        super().__init__(while_op.helper.main_program)
        self.while_op = while_op

    def __enter__(self):
        self.while_op.status = While.IN_WHILE_BLOCK
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            return False
        self.while_op.status = While.AFTER_WHILE_BLOCK
        self.while_op._complete()
        return super().__exit__(exc_type, exc_val, exc_tb)


# (TODO: Mine) There exists dependency. It will be removed later.
def get_inputs_outputs_in_block(
    current_block, inner_inputs, inner_outputs, helper
):
    """
    Find inputs and outputs in current control flow block.
    :param current_block: Current control flow block.
    :param inner_inputs: Input var name of ops in current block.
    :param inner_outputs: Output var name of ops in current block.
    :return: inner_inputs, inner_outputs
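
    Illustration (conceptual, not a doctest): for a block whose ops are
        tmp = x + y
        out = tmp * z
    the returned inner_inputs contains {x, y, z} (plus whatever was passed in)
    and inner_outputs contains {tmp, out}.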
    """

    def is_ignore_vars(op, var_name):
        # NOTE(dev): There are some persistable vars created in some non-standard APIs
        # such as "contrib.layers.shuffle_batch", which create a "Seed" used both as
        # Input and Output. Such a var shall not be considered as a loop_var in
        # control_flow.
        IGNORE_VAR_NAMES = {"shuffle_batch": ["shuffle_batch_seed"]}
        if op.type in IGNORE_VAR_NAMES:
            var_names = IGNORE_VAR_NAMES[op.type]
            for name in var_names:
                if name in var_name:
                    return True
        return False

    # Step1: update inner_inputs and inner_outputs
    # NOTE: Here we assume that all variables are input or output of Ops,
    # but some variables are created without appending a real op.
    # For example, in `arr = create_array(dtype)`, `arr` is not an output of an op.
    for op in current_block.ops:
        assert isinstance(op, Operator)
        for iname in op.input_names:
            for in_var_name in op.input(iname):
                if in_var_name not in inner_outputs and not is_ignore_vars(
                    op, in_var_name
                ):
                    inner_inputs.add(in_var_name)

        for oname in op.output_names:
            for out_var_name in op.output(oname):
                inner_outputs.add(out_var_name)

    # Step2: Remove LOD_TENSOR_ARRAY created in current control flow block.
    remove_inner_inputs = set()
    parent_block = helper.main_program.block(current_block.parent_idx)

    for in_var_name in inner_inputs:
        parent_block_var = parent_block._find_var_recursive(in_var_name)
        current_block_var = None
        if current_block.has_var(in_var_name):
            current_block_var = current_block.var(in_var_name)
        if (
            not parent_block_var
            and current_block_var
            and current_block_var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
        ):
            remove_inner_inputs.add(in_var_name)

    inner_inputs = inner_inputs - remove_inner_inputs

    return inner_inputs, inner_outputs


# (TODO: Mine) There exists dependency. It will be removed later.
class While:
    """
    :api_attr: Static Graph

    While loop control flow. Repeats the while body until cond is False.

    Note:
        A new OP :ref:`api_fluid_layers_while_loop` is highly recommended instead of ``While`` if the shape of parameter ``cond`` is [1].
        OP :ref:`api_fluid_layers_while_loop` is easier to use and is called with less code but does the same thing as ``While`` .

    Notice:
        Local variables created in ``While`` are similar to those created in a C++ while loop and cannot be referenced externally.
        As a result, they cannot be obtained through ``fetch_list`` of ``Executor``. If you would like to access a variable
        outside of ``while`` , PaddlePaddle provides the ``assign`` API to assign local variables to external ones. Please refer to example
        code 2 or refer to `issue#22724 <https://github.com/PaddlePaddle/Paddle/issues/22724>`_.

    Args:
        cond(Variable): A Tensor whose data type is bool controlling whether to continue looping.
        is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
        name(str, optional): The default value is None.  Normally there is no need for user to set this property.  For more information, please refer to :ref:`api_guide_Name` .

    Examples 1:
          .. code-block:: python

            import paddle.fluid as fluid
            import paddle
            import numpy as np

            paddle.enable_static()
            i = paddle.full(shape=[1], dtype='int64', fill_value=0)           # loop counter

            loop_len = paddle.full(shape=[1],dtype='int64', fill_value=10)    # loop length

            cond = paddle.less_than(x=i, y=loop_len)
            while_op = fluid.layers.While(cond=cond)
            with while_op.block():
                i = paddle.increment(x=i, value=1)
                paddle.assign(paddle.less_than(x=i, y=loop_len), cond)

            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())

            res = exe.run(fluid.default_main_program(), feed={}, fetch_list=[i])
            print(res) # [array([10])]


    Examples 2:
          .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            import numpy as np

            paddle.enable_static()
            i = paddle.full(shape=[1], dtype='int64', fill_value=0)
            loop_len = paddle.full(shape=[1], dtype='int64', fill_value=10)
            one = paddle.full(shape=[1], dtype='float32', fill_value=1)
            data = paddle.static.data(name='data', shape=[1], dtype='float32')
            sums = paddle.full(shape=[1], dtype='float32', fill_value=0)  # Define the variable to be fetched outside of While; its name should differ from that of the variable to be obtained inside the While

            cond = paddle.less_than(x=i, y=loop_len)
            while_op = fluid.layers.While(cond=cond)
            with while_op.block():
                sums_tensor = paddle.add(x=data, y=data)
                fluid.layers.assign(sums_tensor, sums)  # Copy the value of sums_tensor defined inside While into sums defined outside of While through layers.assign
                i = paddle.increment(x=i, value=1)
                data = paddle.add(x=data, y=one)
                paddle.assign(paddle.less_than(x=i, y=loop_len), cond)

            feed_data = np.ones(1).astype('float32')
            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())
            res = exe.run(fluid.default_main_program(), feed={'data': feed_data}, fetch_list=sums)
            print(res[0])  # [2.]    # Because data inside While does not update its value outside the While, the value of sums is [2.] after the loop
    """

    BEFORE_WHILE_BLOCK = 0
    IN_WHILE_BLOCK = 1
    AFTER_WHILE_BLOCK = 2

    def __init__(self, cond, is_test=False, name=None):
        self.helper = LayerHelper("while", name=name)
        self.status = While.BEFORE_WHILE_BLOCK
        check_variable_and_dtype(cond, 'cond', ['bool'], 'fluid.layers.While')
        if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
            raise TypeError(
                "condition expected shape as [1], but given shape as {0}.".format(
                    list(cond.shape)
                )
            )
        self.cond_var = cond
        self.is_test = is_test

    def block(self):
        return WhileGuard(self)

    def _complete(self):
        main_program = self.helper.main_program
        while_block = main_program.current_block()
        parent_block = main_program.block(
            main_program.current_block().parent_idx
        )

        inner_outputs = {self.cond_var.name}
        x_name_list = set()
        x_name_list, inner_outputs = get_inputs_outputs_in_block(
            while_block, x_name_list, inner_outputs, self.helper
        )

        out_vars = []
        for inner_out_name in inner_outputs:
            inner_var = parent_block._find_var_recursive(inner_out_name)
            if inner_var:
                out_vars.append(inner_var)

        x_name_list |= set(map(lambda x: x.name, out_vars))
        # NOTE(dev): cond_var has been contained in Input('Condition'), so
        # we remove it from Input('X')
        x_name_list -= {self.cond_var.name}

        step_scope = parent_block.create_var(
            type=core.VarDesc.VarType.STEP_SCOPES
        )

        parent_block.append_op(
            type='while',
            inputs={
                'X': [
                    parent_block._var_recursive(x_name)
                    for x_name in x_name_list
                ],
                'Condition': [self.cond_var],
            },
            outputs={'Out': out_vars, 'StepScopes': [step_scope]},
            attrs={'sub_block': while_block, "is_test": self.is_test},
        )


support_ret_buildin_type = (bool, float, int)


# (TODO: Mine) There exists dependency. It will be removed later.
def assign_skip_lod_tensor_array(input, output):
    """
    Assign input to output, but skip the process of copying LoDTensorArray unless it's created in while_block.
    """

    def has_shape_diff(x_var, y_var):
        if len(x_var.shape) != len(y_var.shape):
            return True
        for x_dim, y_dim in zip(x_var.shape, y_var.shape):
            if x_dim != y_dim and -1 not in [x_dim, y_dim]:
                return True
        return False

    if not isinstance(input, (Variable, core.eager.Tensor)):
        if isinstance(output, Variable) and isinstance(
            input, support_ret_buildin_type
        ):
            paddle.assign(input, output)
        else:
            output = input
        return

    if input.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
        main_program = input.block.program
        parent_block = main_program.block(
            main_program.current_block().parent_idx
        )
        if parent_block and not parent_block._find_var_recursive(input.name):
            paddle.assign(input, output)
    else:
        if (
            isinstance(output, Variable)
            and isinstance(input, Variable)
            and has_shape_diff(input, output)
        ):
            warnings.warn(
                "In dy2static mode, we attempt to assign a variable with shape {} into a variable with shape {}, which is not always right.".format(
                    input.shape, output.shape
                )
            )
        paddle.assign(input, output)


# (TODO: Mine) There exists dependency (jit.dy2static.convert_operators). It will be removed later.
def while_loop(cond, body, loop_vars, is_test=False, name=None):
    """
    :api_attr: Static Graph

    while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False.

    Notice:
        Local variables defined in ``body`` cannot be obtained through ``fetch_list`` of ``Executor`` , variables should
        be defined outside ``body`` and placed in ``loop_vars`` for looping, then these variables can be fetched by ``fetch_list`` .

    Args:
        cond(Callable): A callable returning a boolean tensor controlling whether to continue looping. And ``cond`` takes
            as many arguments as ``loop_vars`` .
        body(Callable): A callable returning a tuple or list of tensors or LoDTensorArrays of the same arity
            (length and structure) and types as ``loops_vars`` . And ``body`` takes as many arguments as ``loop_vars`` .
        loop_vars(list|tuple): A list or tuple of tensors or LoDTensorArrays that is passed to both ``cond`` and ``body`` .
        is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
        name(str, optional): Normally there is no need for users to set this property. For more information, please
            refer to :ref:`api_guide_Name`. Default is None.

    Returns:
        A list or tuple of Tensors or LoDTensorArrays which are returned by ``body`` .

    Examples:
        .. code-block:: python

            import paddle
            paddle.enable_static()

            def cond(i, ten):
                return i < ten

            def body(i, ten):
                i = i + 1
                return [i, ten]

            main_program = paddle.static.default_main_program()
            startup_program = paddle.static.default_startup_program()
            with paddle.static.program_guard(main_program, startup_program):
                i = paddle.full(shape=[1], fill_value=0, dtype='int64')     # loop counter
                ten = paddle.full(shape=[1], fill_value=10, dtype='int64')  # loop length
                i, ten = paddle.static.nn.while_loop(cond, body, [i, ten])

                exe = paddle.static.Executor(paddle.CPUPlace())
                res = exe.run(main_program, feed={}, fetch_list=[i])
                print(res) # [array([10])]
    """
    helper = LayerHelper('while_loop', **locals())

    if not callable(cond):
        raise TypeError("cond in while_loop should be callable")
    if not callable(body):
        raise TypeError("body in while_loop should be callable")
    check_type(loop_vars, 'loop_vars', (list, tuple), 'fluid.layers.while_loop')
    if len(loop_vars) == 0:
        raise ValueError("loop_vars in while_loop should not be empty")

    pre_cond = cond(*loop_vars)

    if reduce(lambda a, b: a * b, pre_cond.shape, 1) != 1:
        raise TypeError(
            "the shape of the variable returned by cond should be [1],"
            "but given shape as {0}.".format(list(pre_cond.shape))
        )

    if in_dygraph_mode():
        now_cond = pre_cond.item()
        while now_cond:
            output_vars = body(*loop_vars)
            if not isinstance(output_vars, (list, tuple)):
                output_vars = [output_vars]
            if len(output_vars) != len(loop_vars):
                raise ValueError(
                    "body in while_loop should return the same arity "
                    "(length and structure) and types as loop_vars"
                )
            now_cond = cond(*output_vars).item()
            map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
        return loop_vars
    else:
        check_variable_and_dtype(
            pre_cond,
            'var of cond returned',
            ['bool'],
            'fluid.layers.while_loop',
        )
        while_loop_block = While(pre_cond, is_test, name)
        has_mutable_vars_in_loop = hold_mutable_vars(loop_vars)
        with while_loop_block.block():
            # If a variable with mutable type is included in loop_vars, like `dict/list`,
            # modifying it in the body function will cause origin variable to be modified
            # synchronously. This will raise an assignment error out of while block.
            # Here we make a copy of the mutable vars to avoid this problem.
            if has_mutable_vars_in_loop:
                new_loop_vars = copy_mutable_vars(loop_vars)
                output_vars = body(*new_loop_vars)
            else:
                output_vars = body(*loop_vars)
            if not isinstance(output_vars, (list, tuple)):
                output_vars = [output_vars]
            try:
                loop_vars = _deal_with_undefined_var(output_vars, loop_vars)
                assert_same_structure(output_vars, loop_vars, check_types=False)
            except ValueError as e:
                raise ValueError(
                    "body in while_loop should return the same arity "
                    "(length and structure) as loop_vars: {0}".format(e)
                )
            now_cond = cond(*output_vars)
            map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
            paddle.assign(now_cond, pre_cond)
        return loop_vars


# (TODO: Mine) There exists dependency. It will be removed later.
def _deal_with_undefined_var(output_vars, loop_vars):
    """Deal with undefined var cases. We create undefined variables based on the results of body().
    In Dy2Static, we use an undefined var to represent a var created in control flow. This function
    expands the loop_vars and replaces the original loop_vars.
    1. UndefinedVar = Variable      # create a variable
    2. UndefinedVar = None          # create an undefined var with RETURN_NO_VALUE_MAGIC_NUM
    3. UndefinedVar = List(int)     # create a list of variables
    4. UndefinedVar = value         # create a variable
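
    A conceptual illustration (hypothetical variable names, not a doctest):
        output_vars = (x_var, [y_var, z_var]) and loop_vars = (UndefinedVar('a'), None)
        -> returns (an undefined variable, [an undefined variable, an undefined variable]),
        i.e. each UndefinedVar/None slot is re-created based on the structure of body()'s output.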
    """
    from paddle.jit.dy2static.utils import (
        UndefinedVar,
        create_undefined_variable,
    )

    def create_var_like(o_var):
        if (
            isinstance(o_var, (Variable,) + support_ret_buildin_type)
            or o_var is None
        ):
            return create_undefined_variable()
        if is_sequence(o_var):
            """
            Create a complex container class inside the body of while, including Python list and Python dict.
            """
            return map_structure(lambda x: create_undefined_variable(), o_var)

    if len(output_vars) != len(loop_vars):
        raise ValueError("output_vars and loop_vars should have the same length.")

    results = []
    for o_var, l_var in zip(output_vars, loop_vars):
        if isinstance(l_var, UndefinedVar) or l_var is None:
            results.append(create_var_like(o_var))
        else:
            results.append(l_var)
    return results


class ConditionalBlockGuard(BlockGuard):
    """
    ConditionalBlockGuard is derived from BlockGuard. It is dedicated to
    holding a ConditionalBlock and helping users enter and exit the
    ConditionalBlock via Python's 'with' keyword. However, ConditionalBlockGuard
    is generally an internal component of IfElse; users should not use it directly.
    """

    def __init__(self, block):
        check_type(block, "block", ConditionalBlock, "ConditionalBlockGuard")
        super().__init__(block.helper.main_program)
        self.block = block

    def __enter__(self):
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.block.complete()
        return super().__exit__(exc_type, exc_val, exc_tb)


class ConditionalBlock:
    '''
    **ConditionalBlock**

    ConditionalBlock is an operator that binds a block to a specific condition.
    If the condition matches, the corresponding block will be executed.

    Args:
        inputs (Variable): bool conditions.
        is_scalar_condition (bool): whether the branch is controlled by a scalar.
        name(str): name of this ConditionalBlock.

    Examples:
        .. code-block:: python

             import paddle
             import paddle.fluid as fluid
             import paddle.fluid.layers as layers

             cond = paddle.less_than(x=label, y=limit)
             true_image, false_image = layers.split_lod_tensor(
                 input=image, mask=cond)
             true_cond = layers.ConditionalBlock([true_image])
             false_cond = layers.ConditionalBlock([false_image])

             with true_cond.block():
                 ...
             with false_cond.block():
                 ...
    '''

    def __init__(self, inputs, is_scalar_condition=False, name=None):
        for each_input in inputs:
            check_type(each_input, "input", Variable, "ConditionalBlock")
        self.inputs = inputs
        self.is_scalar_condition = is_scalar_condition
        self.helper = LayerHelper('conditional_block', name=name)

    def block(self):
        return ConditionalBlockGuard(self)

    def complete(self):
        inside_block = self.helper.main_program.current_block()
        parent_block = self.helper.main_program.block(inside_block.parent_idx)

        intermediate = set()
        params = set()
        params, intermediate = get_inputs_outputs_in_block(
            inside_block, params, intermediate, helper=self.helper
        )

        # Todo(liym27) Here we assume that all params are in the recursive parent block,
        # but when minimize() is called in control flow, some params may be in the
        # conditional grad block
        param_list = [
            parent_block._var_recursive(each_name) for each_name in params
        ]

        out_list = []
        for inner_out_name in intermediate:
            inner_var = parent_block._find_var_recursive(inner_out_name)
            if inner_var:
                out_list.append(inner_var)

        step_scope = parent_block.create_var(
            type=core.VarDesc.VarType.STEP_SCOPES
        )
        conditional_block_op = parent_block.append_op(
            type='conditional_block',
            inputs={
                'Cond': self.inputs,
                'Input': param_list,
            },
            outputs={'Out': out_list, 'Scope': [step_scope]},
            attrs={
                'sub_block': inside_block,
                'is_scalar_condition': self.is_scalar_condition,
            },
        )

        if self.need_append_conditional_block_grad(inside_block):
            self.append_conditional_block_grad(
                parent_block, inside_block, conditional_block_op
            )

    def need_append_conditional_block_grad(self, inside_block):
        grad_sub_block_idx = inside_block.backward_block_idx
        inside_block_idx = inside_block.idx

        # if inside_block has a grad_block and the grad_block is not itself,
        # we will append conditional block grad.
        return (
            grad_sub_block_idx != -1 and grad_sub_block_idx != inside_block_idx
        )

    def append_conditional_block_grad(
        self, parent_block, inside_block, conditional_block_op
    ):
        '''
        Append op `conditional_block_grad` manually.
        When `optimizer.minimize/append_backward` is called in Paddle control flow,
        grad ops will be appended before appending op `conditional_block` so that
        op `conditional_block_grad` can't be appended when calling
        `optimizer.minimize/append_backward`. After appending op `conditional_block`,
        `conditional_block_grad` is appended manually.

        Args:
            parent_block (Block): The block that `conditional_block_op` belongs to.
            inside_block (Block): The sub block of `conditional_block_op`.
            conditional_block_op (Operator): The forward op conditional_block.
        '''

        grad_sub_block_idx = inside_block.backward_block_idx
        grad_sub_block = self.helper.main_program.block(grad_sub_block_idx)

        intermediate = set()
        params = set()

        for each_op in grad_sub_block.ops:
            assert isinstance(each_op, Operator)
            for iname in each_op.input_names:
                for in_var_name in each_op.input(iname):
                    if in_var_name not in intermediate:
                        params.add(in_var_name)

            for oname in each_op.output_names:
                for out_var_name in each_op.output(oname):
                    intermediate.add(out_var_name)

        param_list = []
        for inner_input_name in params:
            inner_var = parent_block._find_var_recursive(inner_input_name)
            if inner_var:
                param_list.append(inner_var.name)

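        # Let the core generate the conditional_block_grad op desc from the forward
        # op, attaching the already-built grad sub-block as its sub-block.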
        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
            conditional_block_op.desc, set(), [grad_sub_block.desc]
        )

        # Append the generated grad op desc to the parent block and mark it as a backward op.
        op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
        backward = core.op_proto_and_checker_maker.OpRole.Backward
        new_op_desc = parent_block.desc.append_op()
        new_op_desc.copy_from(grad_op_desc[0])
        new_op_desc._set_attr(op_role_attr_name, backward)
        # set input and output manually
        new_op_desc.set_input('Input', param_list)
        new_op_desc.set_output(
            'Input@GRAD', [param + "@GRAD" for param in param_list]
        )

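        # Declare any output gradient variables that do not exist yet in the grad
        # sub-block (the empty placeholder name is skipped).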
        new_vars = set()
        for grad_var_name in new_op_desc.output_arg_names():
            if (
                grad_sub_block.desc.has_var_recursive(grad_var_name.encode())
                or grad_var_name == core.empty_var_name()
            ):
                continue
            grad_sub_block.desc.var(grad_var_name.encode())
            new_vars.add(grad_var_name)
            if grad_var_name not in op_grad_to_var:
                continue

        # infer_shape and infer_type
        new_op_desc.infer_var_type(grad_sub_block.desc)
        new_op_desc.infer_shape(grad_sub_block.desc)

        for arg in new_op_desc.output_arg_names():
            if arg in new_vars:
                _infer_var_data_type_shape_(arg, grad_sub_block)

        self.helper.main_program._sync_with_cpp()


def _to_sequence_except_dict(x):
    """
    In this function, dict is not viewed as sequence.
    """
    if isinstance(x, dict):
        return [x]
    return to_sequence(x)


def _is_sequence_except_dict(x):
    """
    In this function, dict is not viewed as sequence.
    """
    if isinstance(x, dict):
        return False
    return is_sequence(x)


def expand_undefined_var(nest1, nest2, names):
    """Align two nested structures (e.g. the outputs of the true and false
    branches of ``cond``): positions that one side leaves as UndefinedVar or
    None are padded to mirror the structure of the other side.

    TODO: make this function recursive.
    nest1: Var1, (UndefinedVar, [1,2,3])
    nest2: Var2, ([1,2,3,4], UndefinedVar)
    In this case, we should not expand recursively.
    """
    from paddle.jit.dy2static.utils import UndefinedVar
    from paddle.jit.dy2static.return_transformer import (
        RETURN_VALUE_PREFIX,
    )

    def pack_undefined_var_as(seq):
        return pack_sequence_as(
            seq, [UndefinedVar("padding") for i in flatten(seq)]
        )

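    # map_fn pads n1 with n2's structure when n1 is an UndefinedVar or None
    # (names starting with RETURN_VALUE_PREFIX are left untouched), and warns when
    # one branch sets a var to None while the other gives it a value.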
    def map_fn(n1, n2, name, order):
        if not name.startswith(RETURN_VALUE_PREFIX) and (
            isinstance(n1, UndefinedVar) or n1 is None
        ):
            if n1 is None and n2 is not None:
                if order == 0:
                    warnings.warn(
                        "In cond: Var '{}' or part of it is set differently in ifelse branches, "
                        "<{}, {}> in true branch and <{}, {}> in false branch. Setting the var to "
                        "'None' in an ifelse block might lead to errors.".format(
                            name, type(n1), n1, type(n2), n2
                        )
                    )
                else:
                    warnings.warn(
                        "In cond: Var '{}' or part of it is set differently in ifelse branches, "
                        "<{}, {}> in true branch and <{}, {}> in false branch. Setting the var to "
                        "'None' in an ifelse block might lead to errors.".format(
                            name, type(n2), n2, type(n1), n1
                        )
                    )
            return pack_undefined_var_as(n2)
        return n1

    nest1_out = list(
        map(
            map_fn,
            _to_sequence_except_dict(nest1),
            _to_sequence_except_dict(nest2),
            _to_sequence_except_dict(names),
            [0 for i in _to_sequence_except_dict(names)],
        )
    )
    nest2_out = list(
        map(
            map_fn,
            _to_sequence_except_dict(nest2),
            _to_sequence_except_dict(nest1),
            _to_sequence_except_dict(names),
            [1 for i in _to_sequence_except_dict(names)],
        )
    )
    if not _is_sequence_except_dict(nest1):
        nest1_out = nest1_out[0]
    if not _is_sequence_except_dict(nest2):
        nest2_out = nest2_out[0]
    return nest1_out, nest2_out


# TODO: It will be deleted later.
class Switch:
    """
    :api_attr: Static Graph

    This class is used to implement the Switch branch control flow.
    A Switch contains several case branches and one default branch.
    The Switch control flow checks the case branch conditions in turn and executes
    only the statements after the first case branch whose condition is satisfied.
    If no case branch condition is satisfied,
    only the statements following the default branch are executed.

    Note:
        A new OP :ref:`api_fluid_layers_case` is highly recommended instead of ``Switch`` if the shape of parameter ``cond`` is [1].
        The OP :ref:`api_fluid_layers_case` is easier to use and requires less code, but does the same thing as ``Switch``.

    Member Functions:
        case(condition): A case branch of Switch, whose parameter ``condition`` is a scalar Variable of bool type. The statements after a case branch are executed only if its condition is True and the conditions of all previous case branches are False; otherwise they are skipped.

        default(): The default branch of Switch. When the conditions of all case branches are False, the statements after the default branch are executed.

    Case and default functions can only be used inside the scope of Switch, as shown below:

    .. code-block:: python

        '''
        import paddle
        import paddle.fluid as fluid
        with fluid.layers.Switch() as switch:
            with switch.case(cond1):
                i = paddle.full(shape=[1], dtype='int64', fill_value=1)
            with switch.case(cond2):
                i = paddle.full(shape=[1], dtype='int64', fill_value=2)
            with switch.default():
                i = paddle.full(shape=[1], dtype='int64', fill_value=0)
        '''

    Args:
        name(str, optional): The default value is None. Normally there is no need for users to set this property. For more information, please refer to :ref:`api_guide_Name` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            lr = paddle.static.create_global_var(
                shape=[1],
                value=0.0,
                dtype='float32',
                persistable=True,
                name="learning_rate")
            zero_var = paddle.full(
                shape=[1], dtype='float32', fill_value=0.0)
            one_var = paddle.full(
                shape=[1], dtype='float32', fill_value=1.0)
            two_var = paddle.full(
                shape=[1], dtype='float32', fill_value=2.0)

            global_step = fluid.layers.autoincreased_step_counter(counter_name='@LR_DECAY_COUNTER@', begin=0, step=1)

            with fluid.layers.control_flow.Switch() as switch:
                with switch.case(global_step == zero_var):
                    paddle.assign(input=one_var, output=lr)
                with switch.default():
                    paddle.assign(input=two_var, output=lr)

            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())

            res = exe.run(fluid.default_main_program(), feed={}, fetch_list=[lr])
            print(res) # [array([1.], dtype=float32)]

    """

    def __init__(self, name=None):
        self.helper = LayerHelper('switch', name=name)
        self.inside_scope = False
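        # Accumulated conjunction of the negations of all case conditions seen so
        # far; later case() calls and default() build their guards from it.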
        self.pre_not_conditions = []

    def case(self, condition):
        if not self.inside_scope:
            raise ValueError("case should be called inside with")

        check_variable_and_dtype(
            condition,
            'condition',
            ['bool'],
            'the member function case of fluid.layers.Switch',
        )

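        # A case block runs only when all previous case conditions were False and
        # its own condition is True, i.e. it is guarded by
        # NOT(c1) AND ... AND NOT(c_{k-1}) AND c_k.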
        if len(self.pre_not_conditions) == 0:
            cond_block = ConditionalBlock([condition], is_scalar_condition=True)
            not_cond = paddle.logical_not(x=condition)
            self.pre_not_conditions.append(not_cond)
        else:
            pre_cond_num = len(self.pre_not_conditions)
            pre_not_cond = self.pre_not_conditions[pre_cond_num - 1]
            new_not_cond = paddle.logical_and(
                x=pre_not_cond, y=paddle.logical_not(x=condition)
            )
            self.pre_not_conditions.append(new_not_cond)
            cond_block = ConditionalBlock(
                [paddle.logical_and(x=pre_not_cond, y=condition)],
                is_scalar_condition=True,
            )

        return ConditionalBlockGuard(cond_block)

    def default(self):
        pre_cond_num = len(self.pre_not_conditions)
        if pre_cond_num == 0:
            raise ValueError("there should be at least one condition")
        cond_block = ConditionalBlock(
            [self.pre_not_conditions[pre_cond_num - 1]],
            is_scalar_condition=True,
        )
        return ConditionalBlockGuard(cond_block)

    def __enter__(self):
        """
        Set the flag indicating that execution is now inside the Switch block.
        :return: the Switch object itself.
        """
        self.inside_scope = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.inside_scope = False
        if exc_type is not None:
            return False  # re-raise exception

        return True