#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import collections
import itertools
import six
import math
import sys
import warnings
from functools import partial, reduce

import paddle
from paddle import framework
from paddle.nn import functional as F
from paddle.nn import initializer as I
from paddle.fluid.dygraph import Layer, LayerList
from paddle.fluid.layers import utils
from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as
from paddle.fluid.data_feeder import convert_dtype

# Public names exported by this module.
__all__ = [
    'RNNCellBase',
    'SimpleRNNCell',
    'LSTMCell',
    'GRUCell',
    'RNN',
    'BiRNN',
    'SimpleRNN',
    'LSTM',
    'GRU',
]


def split_states(states, bidirectional=False, state_components=1):
    r"""
    Split the compact (concatenated) states of an RNN network into a possibly
    nested list or tuple of per-cell states.

    Arguments:
        states (Tensor|tuple|list): the concatenated states of the RNN network.
            When `state_components` is 1, `states` is a Tensor of shape
            `(L*D, N, C)`, where `L` is the number of layers of the RNN
            network, `D` is the number of directions (1 for unidirectional
            RNNs and 2 for bidirectional RNNs), `N` is the batch size of the
            input to the RNN network and `C` is its hidden size.

            When `state_components` is larger than 1, `states` is a tuple of
            `state_components` Tensors, each meeting the requirements above.

            For SimpleRNNs and GRUs `state_components` is 1; for LSTMs it
            is 2.
        bidirectional (bool): whether the states belong to a bidirectional
            RNN network. Defaults to False.
        state_components (int): the number of components of the states; see
            `states` above. Defaults to 1.

    Returns:
        A nested list or tuple of RNN cell states.
        If `bidirectional` is True, it is indexed twice to reach a cell
        state: the first index selects the layer, the second the direction.
        If `bidirectional` is False, it is indexed once; the index selects
        the layer.
        When `state_components` is larger than 1, a cell state can be indexed
        once more to get a Tensor of shape `(N, C)`, where `N` is the batch
        size and `C` is the hidden size of the RNN cell.
    """
    if state_components == 1:
        per_cell = paddle.unstack(states)
    else:
        assert len(states) == state_components
        # Unstack each component along its first axis, then regroup so that
        # every entry bundles all components belonging to the same cell.
        per_cell = list(zip(*[paddle.unstack(part) for part in states]))

    if not bidirectional:
        return per_cell
    # Consecutive entries form one layer: even positions become direction
    # index 0, odd positions direction index 1.
    return list(zip(per_cell[::2], per_cell[1::2]))


def concat_states(states, bidirectional=False, state_components=1):
    r"""
    Concatenate a possibly nested list or tuple of RNN cell states into a
    compact form.

    Arguments:
        states (list|tuple): a possibly nested list or tuple of RNN cell
            states.
            If `bidirectional` is True, it can be indexed twice to get an
            RNN cell state. The first index indicates the layer, the second
            index indicates the direction.
            If `bidirectional` is False, it can be indexed once to get an RNN
            cell state. The index indicates the layer.
            Note that if `state_components` is larger than 1, an RNN cell
            state can be indexed one more time to get a tensor of shape
            `(N, C)`, where `N` is the batch size of the input to the RNN
            cell, and `C` is the hidden size of the RNN cell.
        bidirectional (bool): whether the states belong to a bidirectional
            RNN network. Defaults to False. The computation itself does not
            read this flag; the nesting is removed by flattening.
        state_components (int): the number of components of the states; see
            `states` above. Defaults to 1.

    Returns:
        Concatenated states for the RNN network.
        When `state_components` is 1, a Tensor of shape `(L\*D, N, C)`, where
        `L` is the number of layers of the RNN network, `D` is the number of
        directions (1 for unidirectional RNNs and 2 for bidirectional RNNs),
        `N` is the batch size of the input to the RNN network, and `C` is
        the hidden size of the RNN network.
        When `state_components` is larger than 1, a list of
        `state_components` such Tensors, one per state component.
    """
    flat_states = flatten(states)
    if state_components == 1:
        return paddle.stack(flat_states)
    # After flattening, the components of each cell state are interleaved,
    # so component k of every cell lives at positions k, k + state_components,
    # k + 2 * state_components, ...
    return [
        paddle.stack(flat_states[k::state_components])
        for k in range(state_components)
    ]


class RNNCellBase(Layer):
    r"""
    RNNCellBase is the base class for abstraction representing the calculations
    mapping the input and state to the output and new state. It is suitable to
    and mostly used in RNN.
    """

    def get_initial_states(self,
                           batch_ref,
                           shape=None,
                           dtype=None,
                           init_value=0.,
                           batch_dim_idx=0):
        r"""
        Generate initialized states according to provided shape, data type and
        value.

        Arguments:
            batch_ref (Tensor): A tensor whose shape is used to determine the
                batch size of the generated initial states. For `batch_ref`'s
                shape d, `d[batch_dim_idx]` is treated as batch size. If a
                nested structure is given, only its first flattened element
                is used.
            shape (list|tuple, optional): A (possibly nested structure of)
                shape[s], where a shape is a list/tuple of integers. `-1` (for
                batch size) will be automatically prepended if a shape does
                not start with it. If None, property `state_shape` will be
                used. Defaults to None.
            dtype (str|list|tuple, optional): A (possibly nested structure of)
                data type[s]. The structure must be the same as that of
                `shape`, except when all tensors in states have the same data
                type, a single data type can be used. If None and property
                `state_dtype` is not available, the current default floating
                type of paddle is used. Defaults to None.
            init_value (float, optional): A float value used to initialize
                states. Defaults to 0.
            batch_dim_idx (int, optional): An integer indicating which
                dimension of `batch_ref` represents batch. Defaults to 0.

        Returns:
            init_states (Tensor|tuple|list): tensor of the provided shape and
                dtype, or list of tensors that each satisfies the requirements,
                packed in the same structure as `shape` and `type` does.

        Raises:
            NotImplementedError: if `shape` is None and the cell does not
                implement `state_shape`.
        """
        # `collections.Sequence` was removed in Python 3.10; the ABC lives in
        # `collections.abc` (Python 3.3+). Fall back for Python 2.
        try:
            from collections.abc import Sequence
        except ImportError:  # Python 2
            from collections import Sequence

        # TODO: use inputs and batch_size
        batch_ref = flatten(batch_ref)[0]

        if sys.version_info < (3, ):
            integer_types = (
                int,
                long, )
        else:
            integer_types = (int, )

        def _is_shape_sequence(seq):
            """Replacement for `utils.is_sequence` while mapping over shapes.

            A list/tuple of integers is a single shape and is therefore
            treated as a leaf (not a sequence) so it is not recursed into.
            """
            if isinstance(seq, (list, tuple)):
                if all(isinstance(x, integer_types) for x in seq):
                    return False
            # TODO: Add check for the illegal
            if isinstance(seq, dict):
                return True
            return (isinstance(seq, Sequence) and
                    not isinstance(seq, six.string_types))

        class Shape(object):
            # Wraps one shape so it is an opaque leaf; prepends -1 for the
            # batch dimension when the shape does not already start with it.
            def __init__(self, shape):
                if len(shape) > 0 and shape[0] == -1:
                    self.shape = shape
                else:
                    self.shape = [-1] + list(shape)

        # nested structure of shapes
        states_shapes = self.state_shape if shape is None else shape
        # Temporarily swap the module-level `utils.is_sequence` so that
        # `map_structure` stops at individual shapes. Restore it in `finally`
        # so the global is never left patched if mapping raises.
        is_sequence_ori = utils.is_sequence
        utils.is_sequence = _is_shape_sequence
        try:
            states_shapes = map_structure(Shape, states_shapes)
        finally:
            utils.is_sequence = is_sequence_ori

        # nested structure of dtypes
        try:
            states_dtypes = self.state_dtype if dtype is None else dtype
        except NotImplementedError:
            states_dtypes = framework.get_default_dtype()
        # A single dtype is broadcast to every state in the shape structure.
        if len(flatten(states_dtypes)) == 1:
            dtype = flatten(states_dtypes)[0]
            states_dtypes = map_structure(lambda shape: dtype, states_shapes)

        init_states = map_structure(
            lambda shape, dtype: paddle.fluid.layers.fill_constant_batch_size_like(
                input=batch_ref,
                shape=shape.shape,
                dtype=dtype,
                value=init_value,
                input_dim_idx=batch_dim_idx), states_shapes, states_dtypes)
        return init_states

    @property
    def state_shape(self):
        r"""
        Abstract method (property).
        Used to initialize states.
        A (possibly nested structure of) shape[s], where a shape is a
        list/tuple of integers (-1 for batch size would be automatically
        inserted into a shape if shape is not started with it).
        Not necessary to be implemented if states are not initialized by
        `get_initial_states` or the `shape` argument is provided when using
        `get_initial_states`.
        """
        raise NotImplementedError(
            "Please add implementaion for `state_shape` in the used cell.")

    @property
    def state_dtype(self):
        r"""
        Abstract method (property).
        Used to initialize states.
        A (possibly nested structure of) data type[s]. The structure must be
        the same as that of `shape`, except when all tensors in states have
        the same data type, a single data type can be used.
        Not necessary to be implemented if states are not initialized
        by `get_initial_states` or the `dtype` argument is provided when using
        `get_initial_states`.
        """
        raise NotImplementedError(
            "Please add implementaion for `state_dtype` in the used cell.")


class SimpleRNNCell(RNNCellBase):
    r"""
    Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
    computes the outputs and updates states.

    The formula used is as follows:

    .. math::

        h_{t} & = \mathrm{act}(W_{ih}x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})

        y_{t} & = h_{t}

    where :math:`\mathrm{act}` is the activation function selected by
    `activation` (`tanh` or `relu`).

    Please refer to `Finding Structure in Time
    <https://crl.ucsd.edu/~elman/Papers/fsit.pdf>`_ for more details.

    Arguments:
        input_size (int): The input size.
        hidden_size (int): The hidden size.
        activation (str, optional): The activation in the SimpleRNN cell.
            It can be `tanh` or `relu`. Defaults to `tanh`.
        weight_ih_attr (ParamAttr, optional): The parameter attribute for
            `weight_ih`. Default: None.
        weight_hh_attr(ParamAttr, optional): The parameter attribute for
            `weight_hh`. Default: None.
        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
            `bias_ih`. Default: None.
        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
            `bias_hh`. Default: None.
        name (str, optional): Name for the operation (optional, default is
            None). For more information, please refer to :ref:`api_guide_Name`.

    Parameters:
        weight_ih (Parameter): shape (hidden_size, input_size), input to hidden
            weight, corresponding to :math:`W_{ih}` in the formula.
        weight_hh (Parameter): shape (hidden_size, hidden_size), hidden to
            hidden weight, corresponding to :math:`W_{hh}` in the formula.
        bias_ih (Parameter): shape (hidden_size, ), input to hidden bias,
            corresponding to :math:`b_{ih}` in the formula.
        bias_hh (Parameter): shape (hidden_size, ), hidden to hidden bias,
            corresponding to :math:`b_{hh}` in the formula.

    Inputs:
        inputs (Tensor): shape `[batch_size, input_size]`, the input,
            corresponding to :math:`x_t` in the formula.
        states (Tensor, optional): shape `[batch_size, hidden_size]`, the
            previous hidden state, corresponding to :math:`h_{t-1}` in the
            formula. When states is None, zero state is used. Defaults to
            None.

    Returns:
        (outputs, new_states)
        outputs (Tensor): shape `[batch_size, hidden_size]`, the output,
            corresponding to :math:`h_{t}` in the formula.
        states (Tensor): shape `[batch_size, hidden_size]`, the new hidden
            state, corresponding to :math:`h_{t}` in the formula.

    Raises:
        ValueError: if `activation` is neither `tanh` nor `relu`.

    Notes:
        All the weights and bias are initialized with `Uniform(-std, std)` by
        default. Where std = :math:`\frac{1}{\sqrt{hidden\_size}}`. For more
        information about parameter initialization, please refer to
        :ref:`api_fluid_ParamAttr`.

    Examples:

        .. code-block:: python

            import paddle
            paddle.disable_static()

            x = paddle.randn((4, 16))
            prev_h = paddle.randn((4, 32))

            cell = paddle.nn.SimpleRNNCell(16, 32)
            y, h = cell(x, prev_h)

    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 activation="tanh",
                 weight_ih_attr=None,
                 weight_hh_attr=None,
                 bias_ih_attr=None,
                 bias_hh_attr=None,
                 name=None):
        super(SimpleRNNCell, self).__init__()
        # Default initialization range for every weight and bias.
        std = 1.0 / math.sqrt(hidden_size)
        self.weight_ih = self.create_parameter(
            (hidden_size, input_size),
            weight_ih_attr,
            default_initializer=I.Uniform(-std, std))
        self.weight_hh = self.create_parameter(
            (hidden_size, hidden_size),
            weight_hh_attr,
            default_initializer=I.Uniform(-std, std))
        self.bias_ih = self.create_parameter(
            (hidden_size, ),
            bias_ih_attr,
            is_bias=True,
            default_initializer=I.Uniform(-std, std))
        self.bias_hh = self.create_parameter(
            (hidden_size, ),
            bias_hh_attr,
            is_bias=True,
            default_initializer=I.Uniform(-std, std))

        self.input_size = input_size
        self.hidden_size = hidden_size
        if activation not in ["tanh", "relu"]:
            raise ValueError(
                "activation for SimpleRNNCell should be tanh or relu, "
                "but get {}".format(activation))
        self.activation = activation
        self._activation_fn = paddle.tanh \
            if activation == "tanh" \
            else F.relu

    def forward(self, inputs, states=None):
        if states is None:
            states = self.get_initial_states(inputs, self.state_shape)
        pre_h = states
        # Weights are stored as (out, in), hence transpose_y=True.
        i2h = paddle.matmul(inputs, self.weight_ih, transpose_y=True)
        # Biases are added only when the parameter exists (None presumably
        # means the bias was disabled via its attr -- TODO confirm).
        if self.bias_ih is not None:
            i2h = i2h + self.bias_ih
        h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
        if self.bias_hh is not None:
            h2h = h2h + self.bias_hh
        h = self._activation_fn(i2h + h2h)
        # The output and the new state are the same tensor.
        return h, h

    @property
    def state_shape(self):
        r"""
        The `state_shape` of SimpleRNNCell is a shape `[hidden_size]` (-1 for
        batch size would be automatically inserted into shape). The shape
        corresponds to :math:`h_{t-1}`.
        """
        return (self.hidden_size, )


class LSTMCell(RNNCellBase):
    r"""
    Long-Short Term Memory(LSTM) RNN cell. Given the inputs and previous states,
    it computes the outputs and updates states.

    The formula used is as follows:

    .. math::

        i_{t} & = \sigma(W_{ii}x_{t} + b_{ii} + W_{hi}h_{t-1} + b_{hi})

        f_{t} & = \sigma(W_{if}x_{t} + b_{if} + W_{hf}h_{t-1} + b_{hf})

        o_{t} & = \sigma(W_{io}x_{t} + b_{io} + W_{ho}h_{t-1} + b_{ho})

        \widetilde{c}_{t} & = \tanh(W_{ig}x_{t} + b_{ig} + W_{hg}h_{t-1} + b_{hg})

        c_{t} & = f_{t} * c_{t-1} + i_{t} * \widetilde{c}_{t}

        h_{t} & = o_{t} * \tanh(c_{t})

        y_{t} & = h_{t}

    where :math:`\sigma` is the sigmoid function, and \* is the elementwise
    multiplication operator.

    Please refer to `An Empirical Exploration of Recurrent Network Architectures
    <http://proceedings.mlr.press/v37/jozefowicz15.pdf>`_ for more details.

    Arguments:
        input_size (int): The input size.
        hidden_size (int): The hidden size.
        weight_ih_attr(ParamAttr, optional): The parameter attribute for
            `weight_ih`. Default: None.
        weight_hh_attr(ParamAttr, optional): The parameter attribute for
            `weight_hh`. Default: None.
        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
            `bias_ih`. Default: None.
        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
            `bias_hh`. Default: None.
        name (str, optional): Name for the operation (optional, default is
            None). For more information, please refer to :ref:`api_guide_Name`.

    Parameters:
        weight_ih (Parameter): shape (4 * hidden_size, input_size), input to
            hidden weight, which corresponds to the concatenation of
            :math:`W_{ii}, W_{if}, W_{ig}, W_{io}` in the formula.
        weight_hh (Parameter): shape (4 * hidden_size, hidden_size), hidden to
            hidden weight, which corresponds to the concatenation of
            :math:`W_{hi}, W_{hf}, W_{hg}, W_{ho}` in the formula.
        bias_ih (Parameter): shape (4 * hidden_size, ), input to hidden bias,
            which corresponds to the concatenation of
            :math:`b_{ii}, b_{if}, b_{ig}, b_{io}` in the formula.
        bias_hh (Parameter): shape (4 * hidden_size, ), hidden to hidden bias,
            which corresponds to the concatenation of
            :math:`b_{hi}, b_{hf}, b_{hg}, b_{ho}` in the formula.

    Inputs:
        inputs (Tensor): shape `[batch_size, input_size]`, the input,
            corresponding to :math:`x_t` in the formula.
        states (tuple, optional): a tuple of two tensors, each of shape
            `[batch_size, hidden_size]`, the previous hidden state,
            corresponding to :math:`h_{t-1}, c_{t-1}` in the formula.
            When states is None, zero state is used. Defaults to None.

    Returns:
        (outputs, new_states)
        outputs (Tensor): shape `[batch_size, hidden_size]`, the output,
            corresponding to :math:`h_{t}` in the formula.
        states (tuple): a tuple of two tensors, each of shape
            `[batch_size, hidden_size]`, the new hidden states,
            corresponding to :math:`h_{t}, c_{t}` in the formula.

    Notes:
        All the weights and bias are initialized with `Uniform(-std, std)` by
        default. Where std = :math:`\frac{1}{\sqrt{hidden\_size}}`. For more
        information about parameter initialization, please refer to
        :ref:`api_fluid_ParamAttr`.

    Examples:

        .. code-block:: python

            import paddle
            paddle.disable_static()

            x = paddle.randn((4, 16))
            prev_h = paddle.randn((4, 32))
            prev_c = paddle.randn((4, 32))

            cell = paddle.nn.LSTMCell(16, 32)
            y, (h, c) = cell(x, (prev_h, prev_c))

    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 weight_ih_attr=None,
                 weight_hh_attr=None,
                 bias_ih_attr=None,
                 bias_hh_attr=None,
                 name=None):
        super(LSTMCell, self).__init__()
        # Default initialization range for every weight and bias.
        std = 1.0 / math.sqrt(hidden_size)
        # The four gates are packed along the first axis of each parameter.
        self.weight_ih = self.create_parameter(
            (4 * hidden_size, input_size),
            weight_ih_attr,
            default_initializer=I.Uniform(-std, std))
        self.weight_hh = self.create_parameter(
            (4 * hidden_size, hidden_size),
            weight_hh_attr,
            default_initializer=I.Uniform(-std, std))
        self.bias_ih = self.create_parameter(
            (4 * hidden_size, ),
            bias_ih_attr,
            is_bias=True,
            default_initializer=I.Uniform(-std, std))
        self.bias_hh = self.create_parameter(
            (4 * hidden_size, ),
            bias_hh_attr,
            is_bias=True,
            default_initializer=I.Uniform(-std, std))

        self.hidden_size = hidden_size
        self.input_size = input_size
        self._gate_activation = F.sigmoid
        self._activation = paddle.tanh

    def forward(self, inputs, states=None):
        if states is None:
            states = self.get_initial_states(inputs, self.state_shape)
        pre_hidden, pre_cell = states
        # Weights are stored as (out, in), hence transpose_y=True.
        gates = paddle.matmul(inputs, self.weight_ih, transpose_y=True)
        if self.bias_ih is not None:
            gates = gates + self.bias_ih
        gates = gates + paddle.matmul(
            pre_hidden, self.weight_hh, transpose_y=True)
        if self.bias_hh is not None:
            gates = gates + self.bias_hh

        # Gate order along the last axis: input, forget, cell candidate,
        # output (matches W_ii, W_if, W_ig, W_io in the docstring).
        chunked_gates = paddle.split(gates, num_or_sections=4, axis=-1)

        i = self._gate_activation(chunked_gates[0])
        f = self._gate_activation(chunked_gates[1])
        o = self._gate_activation(chunked_gates[3])
        c = f * pre_cell + i * self._activation(chunked_gates[2])
        h = o * self._activation(c)

        return h, (h, c)

    @property
    def state_shape(self):
        r"""
        The `state_shape` of LSTMCell is a tuple with two shapes:
        `((hidden_size, ), (hidden_size,))`. (-1 for batch size would be
        automatically inserted into shape). These two shapes correspond
        to :math:`h_{t-1}` and :math:`c_{t-1}` separately.
        """
        return ((self.hidden_size, ), (self.hidden_size, ))


class GRUCell(RNNCellBase):
    r"""
    Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
    it computes the outputs and updates states.

    The formula for GRU used is as follows:

    .. math::

        r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})

        z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})

        \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))

        h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}

        y_{t} & = h_{t}

    where :math:`\sigma` is the sigmoid function, and \* is the elementwise
    multiplication operator.

    Please refer to `An Empirical Exploration of Recurrent Network Architectures
    <http://proceedings.mlr.press/v37/jozefowicz15.pdf>`_ for more details.

    Arguments:
        input_size (int): The input size.
        hidden_size (int): The hidden size.
        weight_ih_attr(ParamAttr, optional): The parameter attribute for
            `weight_ih`. Default: None.
        weight_hh_attr(ParamAttr, optional): The parameter attribute for
            `weight_hh`. Default: None.
        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
            `bias_ih`. Default: None.
        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
            `bias_hh`. Default: None.
        name (str, optional): Name for the operation (optional, default is
            None). For more information, please refer to :ref:`api_guide_Name`.

    Parameters:
        weight_ih (Parameter): shape (3 * hidden_size, input_size), input to
            hidden weight, which corresponds to the concatenation of
            :math:`W_{ir}, W_{iz}, W_{ic}` in the formula.
        weight_hh (Parameter): shape (3 * hidden_size, hidden_size), hidden to
            hidden weight, which corresponds to the concatenation of
            :math:`W_{hr}, W_{hz}, W_{hc}` in the formula.
        bias_ih (Parameter): shape (3 * hidden_size, ), input to hidden bias,
            which corresponds to the concatenation of
            :math:`b_{ir}, b_{iz}, b_{ic}` in the formula.
        bias_hh (Parameter): shape (3 * hidden_size, ), hidden to hidden bias,
            which corresponds to the concatenation of
            :math:`b_{hr}, b_{hz}, b_{hc}` in the formula.

    Inputs:
        inputs (Tensor): A tensor with shape `[batch_size, input_size]`,
            corresponding to :math:`x_t` in the formula.
        states (Tensor, optional): A tensor with shape
            `[batch_size, hidden_size]`, corresponding to :math:`h_{t-1}` in
            the formula. When states is None, zero state is used. Defaults
            to None.

    Returns:
        (outputs, new_states)
        outputs (Tensor): shape `[batch_size, hidden_size]`, the output,
            corresponding to :math:`h_{t}` in the formula.
        states (Tensor): shape `[batch_size, hidden_size]`, the new hidden
            state, corresponding to :math:`h_{t}` in the formula.

    Notes:
        All the weights and bias are initialized with `Uniform(-std, std)` by
        default. Where std = :math:`\frac{1}{\sqrt{hidden\_size}}`. For more
        information about parameter initialization, please refer to
        :ref:`api_fluid_ParamAttr`.

    Examples:

        .. code-block:: python

            import paddle
            paddle.disable_static()

            x = paddle.randn((4, 16))
            prev_h = paddle.randn((4, 32))

            cell = paddle.nn.GRUCell(16, 32)
            y, h = cell(x, prev_h)

    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 weight_ih_attr=None,
                 weight_hh_attr=None,
                 bias_ih_attr=None,
                 bias_hh_attr=None,
                 name=None):
        super(GRUCell, self).__init__()
        # Default initialization range for every weight and bias.
        std = 1.0 / math.sqrt(hidden_size)
        # The three gates (r, z, c) are packed along the first axis.
        self.weight_ih = self.create_parameter(
            (3 * hidden_size, input_size),
            weight_ih_attr,
            default_initializer=I.Uniform(-std, std))
        self.weight_hh = self.create_parameter(
            (3 * hidden_size, hidden_size),
            weight_hh_attr,
            default_initializer=I.Uniform(-std, std))
        self.bias_ih = self.create_parameter(
            (3 * hidden_size, ),
            bias_ih_attr,
            is_bias=True,
            default_initializer=I.Uniform(-std, std))
        self.bias_hh = self.create_parameter(
            (3 * hidden_size, ),
            bias_hh_attr,
            is_bias=True,
            default_initializer=I.Uniform(-std, std))

        self.hidden_size = hidden_size
        self.input_size = input_size
        self._gate_activation = F.sigmoid
        self._activation = paddle.tanh

    def forward(self, inputs, states=None):
        if states is None:
            states = self.get_initial_states(inputs, self.state_shape)

        pre_hidden = states
        # Weights are stored as (out, in), hence transpose_y=True.
        x_gates = paddle.matmul(inputs, self.weight_ih, transpose_y=True)
        if self.bias_ih is not None:
            x_gates = x_gates + self.bias_ih
        h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)
        if self.bias_hh is not None:
            h_gates = h_gates + self.bias_hh

        # Split along the last (feature) axis, as LSTMCell does; identical to
        # axis=1 for 2-D inputs and also correct with extra leading dims.
        x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=-1)
        h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=-1)

        r = self._gate_activation(x_r + h_r)
        z = self._gate_activation(x_z + h_z)
        c = self._activation(x_c + r * h_c)  # apply reset gate after mm
        # Equivalent to z * h_{t-1} + (1 - z) * c with one multiplication.
        h = (pre_hidden - c) * z + c

        return h, h

    @property
    def state_shape(self):
        r"""
        The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch
        size would be automatically inserted into shape). The shape corresponds
        to the shape of :math:`h_{t-1}`.
        """
        return (self.hidden_size, )


class RNN(Layer):
    r"""
    Wrapper for RNN, which creates a recurrent neural network with an RNN cell.
    It performs :code:`cell.forward()` repeatedly until it reaches the maximum
    length of `inputs`.

    Arguments:
        cell(RNNCellBase): An instance of `RNNCellBase`.
        is_reverse (bool, optional): Indicate whether to calculate in the reverse
            order of input sequences. Defaults to False.
        time_major (bool, optional): Whether the first dimension of the input
            means the time steps. Defaults to False.

    Inputs:
        inputs (Tensor): A (possibly nested structure of) tensor[s]. The input
            sequences.
            If `time_major` is True, the shape is
            `[time_steps, batch_size, input_size]`, else the shape is
            `[batch_size, time_steps, input_size]`, where `input_size` is the
            input size of the cell.
        initial_states (Tensor|list|tuple, optional): Tensor of a possibly
            nested structure of tensors, representing the initial state for
            the rnn cell. If not provided, `cell.get_initial_states` would be
            called to produce the initial states. Defaults to None.
        sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
            or int32. The valid lengths of input sequences. Defaults to None.
            If `sequence_length` is not None, the inputs are treated as
            padded sequences. In each input sequence, elements whose time step
            index is not less than the valid length are treated as paddings.
        **kwargs: Additional keyword arguments to pass to `forward` of the cell.

    Returns:
        (outputs, final_states)
        outputs (Tensor|list|tuple): the output sequences.
            If `time_major` is True, the shape is
            `[time_steps, batch_size, hidden_size]`, else
            `[batch_size, time_steps, hidden_size]`.
        final_states (Tensor|list|tuple): final states of the cell. Tensor or
            a possibly nested structure of tensors which has the same structure
            as the initial states. Each tensor in final states has the same
            shape and dtype as the corresponding tensor in initial states.

    Notes:
        This class is a low level API for wrapping an rnn cell into an RNN
        network. Users should take care of the state of the cell. If
        `initial_states` is passed to the `forward` method, make sure that it
        satisfies the requirements of the cell.

    Examples:

        .. code-block:: python

            import paddle
            paddle.disable_static()

            inputs = paddle.rand((4, 23, 16))
            prev_h = paddle.randn((4, 32))

            cell = paddle.nn.SimpleRNNCell(16, 32)
            rnn = paddle.nn.RNN(cell)
            outputs, final_states = rnn(inputs, prev_h)

    """

    def __init__(self, cell, is_reverse=False, time_major=False):
        super(RNN, self).__init__()
        self.cell = cell
        if not hasattr(self.cell, "call"):
            # for non-dygraph mode, `rnn` api uses cell.call
            self.cell.call = self.cell.forward
        self.is_reverse = is_reverse
        self.time_major = time_major

    def forward(self,
                inputs,
                initial_states=None,
                sequence_length=None,
                **kwargs):
        # The time-step loop is delegated to `F.rnn`; this layer only holds
        # the cell together with the layout (`time_major`) and direction
        # (`is_reverse`) flags.
        final_outputs, final_states = F.rnn(self.cell,
                                            inputs,
                                            initial_states=initial_states,
                                            sequence_length=sequence_length,
                                            time_major=self.time_major,
                                            is_reverse=self.is_reverse,
                                            **kwargs)
        return final_outputs, final_states


class BiRNN(Layer):
    r"""
    Wrapper for bidirectional RNN, which builds a bidirectional RNN given the
    forward rnn cell and backward rnn cell. A BiRNN applies forward RNN and
    backward RNN with corresponding cells separately and concats the outputs
    along the last axis.

    Arguments:
        cell_fw (RNNCellBase): A RNNCellBase instance used for forward RNN.
        cell_bw (RNNCellBase): A RNNCellBase instance used for backward RNN.
        time_major (bool, optional): Whether the first dimension of the input
            means the time steps. Defaults to False.

    Inputs:
        inputs (Tensor): the input sequences of both RNN.
            If time_major is True, the shape is
            `[time_steps, batch_size, input_size]`, else the shape is
            `[batch_size, time_steps, input_size]`, where input_size is the
            input size of both cells.
        initial_states (list|tuple, optional): A tuple/list of the initial
            states of the forward cell and backward cell. If not provided,
            `cell.get_initial_states` would be called to produce the initial
            states for each cell. Defaults to None.
        sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
            or int32. The valid lengths of input sequences. Defaults to None.
            If `sequence_length` is not None, the inputs are treated as
            padded sequences. In each input sequence, elements whose time step
            index is not less than the valid length are treated as paddings.
        **kwargs: Additional keyword arguments. Arguments passed to `forward`
            for each cell.

    Returns:
        (outputs, final_states)
        outputs (Tensor): the outputs of the bidirectional RNN. It is the
            concatenation of the outputs from the forward RNN and backward
            RNN along the last axis.
            If time major is True, the shape is `[time_steps, batch_size, size]`,
            else the shape is `[batch_size, time_steps, size]`, where size is
            `cell_fw.hidden_size + cell_bw.hidden_size`.
        final_states (tuple): A tuple of the final states of the forward
            cell and backward cell.

    Raises:
        ValueError: if `cell_fw.input_size` differs from `cell_bw.input_size`.

    Notes:
        This class is a low level API for wrapping rnn cells into a BiRNN
        network. Users should take care of the states of the cells.
        If `initial_states` is passed to the `forward` method, make sure that
        it satisfies the requirements of the cells.

    Examples:

        .. code-block:: python

            import paddle
            paddle.disable_static()

            cell_fw = paddle.nn.LSTMCell(16, 32)
            cell_bw = paddle.nn.LSTMCell(16, 32)
            rnn = paddle.nn.BiRNN(cell_fw, cell_bw)

            inputs = paddle.rand((2, 23, 16))
            outputs, final_states = rnn(inputs)

    """

    def __init__(self, cell_fw, cell_bw, time_major=False):
        super(BiRNN, self).__init__()
        self.cell_fw = cell_fw
        self.cell_bw = cell_bw
        # Both directions read the same input tensor, so their input sizes
        # must agree.
        if cell_fw.input_size != cell_bw.input_size:
            raise ValueError("input size of forward cell({}) does not equal "
                             "that of backward cell({})".format(
                                 cell_fw.input_size, cell_bw.input_size))
        for cell in [self.cell_fw, self.cell_bw]:
            if not hasattr(cell, "call"):
                # for non-dygraph mode, `rnn` api uses cell.call
                cell.call = cell.forward
        self.time_major = time_major

    def forward(self,
                inputs,
                initial_states=None,
                sequence_length=None,
                **kwargs):
        # A list/tuple must hold exactly (forward_state, backward_state).
        if isinstance(initial_states, (list, tuple)):
            assert len(initial_states) == 2, \
                "length of initial_states should be 2 when it is a list/tuple"

        outputs, final_states = F.birnn(self.cell_fw, self.cell_bw, inputs,
                                        initial_states, sequence_length,
                                        self.time_major, **kwargs)
        return outputs, final_states


class RNNMixin(LayerList):
    r"""
    Shared `forward` implementation for the stacked networks SimpleRNN, LSTM
    and GRU.

    Subclasses append one `RNN` or `BiRNN` per layer and set the attributes
    read here: `time_major`, `num_layers`, `num_directions`, `hidden_size`,
    `state_components` (number of state tensors per layer) and `dropout`.
    """

    def forward(self, inputs, initial_states=None, sequence_length=None):
        r"""
        Run the stacked layers over `inputs`.

        When `initial_states` is None, zero-filled states of shape
        `[num_layers * num_directions, batch_size, hidden_size]` are created,
        one tensor per state component. Dropout (`upscale_in_train`) is
        applied to the input of every layer except the first.

        Returns:
            (outputs, final_states): the last layer's outputs and the
            per-layer final states merged back together by `concat_states`.
        """
        batch_axis = 0
        if self.time_major:
            batch_axis = 1
        dtype = inputs.dtype

        if initial_states is None:
            zero_shape = (self.num_layers * self.num_directions, -1,
                          self.hidden_size)

            def _zero_state():
                # The batch dimension of `inputs` (at `batch_axis`) is copied
                # into dimension 1 of the new state tensor.
                return paddle.fluid.layers.fill_constant_batch_size_like(
                    inputs, zero_shape, dtype, 0, batch_axis, 1)

            if self.state_components == 1:
                initial_states = _zero_state()
            else:
                initial_states = tuple(
                    _zero_state() for _ in range(self.state_components))

        bidirectional = self.num_directions == 2
        layer_states = split_states(initial_states, bidirectional,
                                    self.state_components)
        collected_states = []

        layer_input = inputs
        for layer_idx, layer in enumerate(self):
            if layer_idx > 0:
                layer_input = F.dropout(
                    layer_input,
                    self.dropout,
                    training=self.training,
                    mode="upscale_in_train")
            outputs, last_state = layer(layer_input, layer_states[layer_idx],
                                        sequence_length)
            collected_states.append(last_state)
            layer_input = outputs

        final_states = concat_states(collected_states, bidirectional,
                                     self.state_components)
        return outputs, final_states


class SimpleRNN(RNNMixin):
    r"""
    Multilayer Elman network (SimpleRNN). It takes input sequences and initial
    states as inputs, and returns the output sequences and the final states.

    Each layer inside the SimpleRNN maps the input sequences and initial states
    to the output sequences and final states in the following manner: at each
    step, it takes step inputs (:math:`x_{t}`) and previous
    states (:math:`h_{t-1}`) as inputs, and returns step outputs (:math:`y_{t}`)
    and new states (:math:`h_{t}`).

    .. math::

        h_{t} & = \mathrm{act}(W_{ih}x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
        y_{t} & = h_{t}

    where :math:`\mathrm{act}` is the activation function selected by
    `activation` (`tanh` or `relu`).

    Arguments:
        input_size (int): The input size for the first layer's cell.
        hidden_size (int): The hidden size for each layer's cell.
        num_layers (int, optional): Number of layers. Defaults to 1.
        activation (str, optional): The activation in each SimpleRNN cell. It
            can be `tanh` or `relu`. Defaults to `tanh`.
        direction (str, optional): The direction of the network. It can be
            "forward", "backward" and "bidirectional". Defaults to "forward".
        dropout (float, optional): The dropout probability. Dropout is applied
            to the input of each layer except for the first layer. Defaults to 0.
        time_major (bool, optional): Whether the first dimension of the input
            means the time steps. Defaults to False.
        weight_ih_attr (ParamAttr, optional): The parameter attribute for
            `weight_ih` of each cell. Defaults to None.
        weight_hh_attr (ParamAttr, optional): The parameter attribute for
            `weight_hh` of each cell. Defaults to None.
        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
            `bias_ih` of each cell. Defaults to None.
        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
            `bias_hh` of each cell. Defaults to None.
        name (str, optional): Name for the operation (optional, default is
            None). For more information, please refer to :ref:`api_guide_Name`.

    Inputs:
        inputs (Tensor): the input sequence.
            If `time_major` is True, the shape is
            `[time_steps, batch_size, input_size]`, else, the shape is
            `[batch_size, time_steps, input_size]`.
        initial_states (Tensor, optional): the initial state. The shape is
            `[num_layers * num_directions, batch_size, hidden_size]`.
            If initial_state is not given, zero initial states are used.
        sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
            or int32. The valid lengths of input sequences. Defaults to None.
            If `sequence_length` is not None, the inputs are treated as
            padded sequences. In each input sequence, elements whose time step
            index is not less than the valid length are treated as paddings.

    Returns:
        (outputs, final_states)
        outputs (Tensor): the output sequence.
            If `time_major` is True, the shape is
            `[time_steps, batch_size, num_directions * hidden_size]`,
            else, the shape is
            `[batch_size, time_steps, num_directions * hidden_size]`.
            Note that `num_directions` is 2 if direction is "bidirectional"
            else 1.
        final_states (Tensor): final states. The shape is
            `[num_layers * num_directions, batch_size, hidden_size]`.
            Note that `num_directions` is 2 if direction is "bidirectional"
            else 1.

    Raises:
        ValueError: if `direction` is not "forward", "backward" or
            "bidirectional".

    Examples:

        .. code-block:: python

            import paddle
            paddle.disable_static()

            rnn = paddle.nn.SimpleRNN(16, 32, 2)

            x = paddle.randn((4, 23, 16))
            prev_h = paddle.randn((2, 4, 32))
            y, h = rnn(x, prev_h)

    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 num_layers=1,
                 activation="tanh",
                 direction="forward",
                 dropout=0.,
                 time_major=False,
                 weight_ih_attr=None,
                 weight_hh_attr=None,
                 bias_ih_attr=None,
                 bias_hh_attr=None,
                 name=None):
        super(SimpleRNN, self).__init__()

        # NOTE(review): the same *_attr objects are passed to every cell of
        # every layer; an attr carrying an explicit parameter name may
        # therefore be shared/colliding across layers — confirm intended.
        if direction in ["forward", "backward"]:
            is_reverse = direction == "backward"
            cell = SimpleRNNCell(input_size, hidden_size, activation,
                                 weight_ih_attr, weight_hh_attr, bias_ih_attr,
                                 bias_hh_attr)
            self.append(RNN(cell, is_reverse, time_major))
            for i in range(1, num_layers):
                # Deeper layers consume the previous layer's hidden output.
                cell = SimpleRNNCell(hidden_size, hidden_size, activation,
                                     weight_ih_attr, weight_hh_attr,
                                     bias_ih_attr, bias_hh_attr)
                self.append(RNN(cell, is_reverse, time_major))
        elif direction == "bidirectional":
            cell_fw = SimpleRNNCell(input_size, hidden_size, activation,
                                    weight_ih_attr, weight_hh_attr,
                                    bias_ih_attr, bias_hh_attr)
            cell_bw = SimpleRNNCell(input_size, hidden_size, activation,
                                    weight_ih_attr, weight_hh_attr,
                                    bias_ih_attr, bias_hh_attr)
            self.append(BiRNN(cell_fw, cell_bw, time_major))
            for i in range(1, num_layers):
                # BiRNN concatenates forward and backward outputs along the
                # last axis, so deeper layers see 2 * hidden_size features.
                cell_fw = SimpleRNNCell(
                    2 * hidden_size, hidden_size, activation, weight_ih_attr,
                    weight_hh_attr, bias_ih_attr, bias_hh_attr)
                cell_bw = SimpleRNNCell(
                    2 * hidden_size, hidden_size, activation, weight_ih_attr,
                    weight_hh_attr, bias_ih_attr, bias_hh_attr)
                self.append(BiRNN(cell_fw, cell_bw, time_major))
        else:
            raise ValueError(
                "direction should be forward, backward or bidirectional, "
                "received direction = {}".format(direction))

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_directions = 2 if direction == "bidirectional" else 1
        self.time_major = time_major
        self.num_layers = num_layers
        # One state tensor (h) per layer.
        self.state_components = 1


class LSTM(RNNMixin):
    r"""
    Multilayer LSTM. It takes a sequence and an initial state as inputs, and
    returns the output sequences and the final states.

    Each layer inside the LSTM maps the input sequences and initial states
    to the output sequences and final states in the following manner: at each
    step, it takes step inputs (:math:`x_{t}`) and previous
    states (:math:`h_{t-1}, c_{t-1}`) as inputs, and returns step
    outputs (:math:`y_{t}`) and new states (:math:`h_{t}, c_{t}`).

    .. math::

        i_{t} & = \sigma(W_{ii}x_{t} + b_{ii} + W_{hi}h_{t-1} + b_{hi})
        f_{t} & = \sigma(W_{if}x_{t} + b_{if} + W_{hf}h_{t-1} + b_{hf})
        o_{t} & = \sigma(W_{io}x_{t} + b_{io} + W_{ho}h_{t-1} + b_{ho})
        \widetilde{c}_{t} & = \tanh(W_{ig}x_{t} + b_{ig} + W_{hg}h_{t-1} + b_{hg})
        c_{t} & = f_{t} * c_{t-1} + i_{t} * \widetilde{c}_{t}
        h_{t} & = o_{t} * \tanh(c_{t})
        y_{t} & = h_{t}

    where :math:`\sigma` is the sigmoid function, and :math:`*` is the
    elementwise multiplication operator.

    Arguments:
        input_size (int): The input size for the first layer's cell.
        hidden_size (int): The hidden size for each layer's cell.
        num_layers (int, optional): Number of layers. Defaults to 1.
        direction (str, optional): The direction of the network. It can be
            "forward", "backward" and "bidirectional". Defaults to "forward".
        dropout (float, optional): The dropout probability. Dropout is applied
            to the input of each layer except for the first layer. Defaults to 0.
        time_major (bool, optional): Whether the first dimension of the input
            means the time steps. Defaults to False.
        weight_ih_attr (ParamAttr, optional): The parameter attribute for
            `weight_ih` of each cell. Default: None.
        weight_hh_attr (ParamAttr, optional): The parameter attribute for
            `weight_hh` of each cell. Default: None.
        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
            `bias_ih` of each cell. Default: None.
        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
            `bias_hh` of each cell. Default: None.
        name (str, optional): Name for the operation (optional, default is
            None). For more information, please refer to :ref:`api_guide_Name`.

    Inputs:
        inputs (Tensor): the input sequence.
            If `time_major` is True, the shape is
            `[time_steps, batch_size, input_size]`, else, the shape is
            `[batch_size, time_steps, input_size]`.
        initial_states (tuple, optional): the initial state, a tuple of (h, c),
            the shape of each is
            `[num_layers * num_directions, batch_size, hidden_size]`.
            If initial_state is not given, zero initial states are used.
        sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
            or int32. The valid lengths of input sequences. Defaults to None.
            If `sequence_length` is not None, the inputs are treated as
            padded sequences. In each input sequence, elements whose time step
            index is not less than the valid length are treated as paddings.

    Returns:
        (outputs, final_states)
        outputs (Tensor): the output sequence.
            If `time_major` is True, the shape is
            `[time_steps, batch_size, num_directions * hidden_size]`,
            else, the shape is
            `[batch_size, time_steps, num_directions * hidden_size]`.
            Note that `num_directions` is 2 if direction is "bidirectional"
            else 1.
        final_states (tuple): the final state, a tuple of two tensors, h and c.
            The shape of each is
            `[num_layers * num_directions, batch_size, hidden_size]`.
            Note that `num_directions` is 2 if direction is "bidirectional"
            else 1.

    Raises:
        ValueError: if `direction` is not "forward", "backward" or
            "bidirectional".

    Examples:

        .. code-block:: python

            import paddle
            paddle.disable_static()

            rnn = paddle.nn.LSTM(16, 32, 2)

            x = paddle.randn((4, 23, 16))
            prev_h = paddle.randn((2, 4, 32))
            prev_c = paddle.randn((2, 4, 32))
            y, (h, c) = rnn(x, (prev_h, prev_c))

    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 num_layers=1,
                 direction="forward",
                 dropout=0.,
                 time_major=False,
                 weight_ih_attr=None,
                 weight_hh_attr=None,
                 bias_ih_attr=None,
                 bias_hh_attr=None,
                 name=None):
        super(LSTM, self).__init__()

        # NOTE(review): the same *_attr objects are passed to every cell of
        # every layer; an attr carrying an explicit parameter name may
        # therefore be shared/colliding across layers — confirm intended.
        if direction in ["forward", "backward"]:
            is_reverse = direction == "backward"
            cell = LSTMCell(input_size, hidden_size, weight_ih_attr,
                            weight_hh_attr, bias_ih_attr, bias_hh_attr)
            self.append(RNN(cell, is_reverse, time_major))
            for i in range(1, num_layers):
                # Deeper layers consume the previous layer's hidden output.
                cell = LSTMCell(hidden_size, hidden_size, weight_ih_attr,
                                weight_hh_attr, bias_ih_attr, bias_hh_attr)
                self.append(RNN(cell, is_reverse, time_major))
        elif direction == "bidirectional":
            cell_fw = LSTMCell(input_size, hidden_size, weight_ih_attr,
                               weight_hh_attr, bias_ih_attr, bias_hh_attr)
            cell_bw = LSTMCell(input_size, hidden_size, weight_ih_attr,
                               weight_hh_attr, bias_ih_attr, bias_hh_attr)
            self.append(BiRNN(cell_fw, cell_bw, time_major))
            for i in range(1, num_layers):
                # BiRNN concatenates forward and backward outputs along the
                # last axis, so deeper layers see 2 * hidden_size features.
                cell_fw = LSTMCell(2 * hidden_size, hidden_size, weight_ih_attr,
                                   weight_hh_attr, bias_ih_attr, bias_hh_attr)
                cell_bw = LSTMCell(2 * hidden_size, hidden_size, weight_ih_attr,
                                   weight_hh_attr, bias_ih_attr, bias_hh_attr)
                self.append(BiRNN(cell_fw, cell_bw, time_major))
        else:
            raise ValueError(
                "direction should be forward, backward or bidirectional, "
                "received direction = {}".format(direction))

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_directions = 2 if direction == "bidirectional" else 1
        self.time_major = time_major
        self.num_layers = num_layers
        # Two state tensors (h, c) per layer.
        self.state_components = 2


class GRU(RNNMixin):
    r"""
    Multilayer GRU. It takes input sequences and initial states as inputs, and
    returns the output sequences and the final states.

    Each layer inside the GRU maps the input sequences and initial states
    to the output sequences and final states in the following manner: at each
    step, it takes step inputs (:math:`x_{t}`) and previous
    states (:math:`h_{t-1}`) as inputs, and returns step outputs (:math:`y_{t}`)
    and new states (:math:`h_{t}`).

    .. math::

        r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
        z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
        \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
        h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
        y_{t} & = h_{t}

    where :math:`\sigma` is the sigmoid function, and :math:`*` is the
    elementwise multiplication operator.

    Arguments:
        input_size (int): The input size for the first layer's cell.
        hidden_size (int): The hidden size for each layer's cell.
        num_layers (int, optional): Number of layers. Defaults to 1.
        direction (str, optional): The direction of the network. It can be
            "forward", "backward" and "bidirectional". Defaults to "forward".
        dropout (float, optional): The dropout probability. Dropout is applied
            to the input of each layer except for the first layer. Defaults to 0.
        time_major (bool, optional): Whether the first dimension of the input
            means the time steps. Defaults to False.
        weight_ih_attr (ParamAttr, optional): The parameter attribute for
            `weight_ih` of each cell. Default: None.
        weight_hh_attr (ParamAttr, optional): The parameter attribute for
            `weight_hh` of each cell. Default: None.
        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
            `bias_ih` of each cell. Default: None.
        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
            `bias_hh` of each cell. Default: None.
        name (str, optional): Name for the operation (optional, default is
            None). For more information, please refer to :ref:`api_guide_Name`.

    Inputs:
        inputs (Tensor): the input sequence.
            If `time_major` is True, the shape is
            `[time_steps, batch_size, input_size]`, else, the shape is
            `[batch_size, time_steps, input_size]`.
        initial_states (Tensor, optional): the initial state. The shape is
            `[num_layers * num_directions, batch_size, hidden_size]`.
            If initial_state is not given, zero initial states are used.
            Defaults to None.
        sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
            or int32. The valid lengths of input sequences. Defaults to None.
            If `sequence_length` is not None, the inputs are treated as
            padded sequences. In each input sequence, elements whose time step
            index is not less than the valid length are treated as paddings.

    Returns:
        (outputs, final_states)
        outputs (Tensor): the output sequence.
            If `time_major` is True, the shape is
            `[time_steps, batch_size, num_directions * hidden_size]`,
            else, the shape is
            `[batch_size, time_steps, num_directions * hidden_size]`.
            Note that `num_directions` is 2 if direction is "bidirectional"
            else 1.
        final_states (Tensor): final states. The shape is
            `[num_layers * num_directions, batch_size, hidden_size]`.
            Note that `num_directions` is 2 if direction is "bidirectional"
            else 1.

    Raises:
        ValueError: if `direction` is not "forward", "backward" or
            "bidirectional".

    Examples:

        .. code-block:: python

            import paddle
            paddle.disable_static()

            rnn = paddle.nn.GRU(16, 32, 2)

            x = paddle.randn((4, 23, 16))
            prev_h = paddle.randn((2, 4, 32))
            y, h = rnn(x, prev_h)

    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 num_layers=1,
                 direction="forward",
                 dropout=0.,
                 time_major=False,
                 weight_ih_attr=None,
                 weight_hh_attr=None,
                 bias_ih_attr=None,
                 bias_hh_attr=None,
                 name=None):
        super(GRU, self).__init__()

        # NOTE(review): the same *_attr objects are passed to every cell of
        # every layer; an attr carrying an explicit parameter name may
        # therefore be shared/colliding across layers — confirm intended.
        if direction in ["forward", "backward"]:
            is_reverse = direction == "backward"
            cell = GRUCell(input_size, hidden_size, weight_ih_attr,
                           weight_hh_attr, bias_ih_attr, bias_hh_attr)
            self.append(RNN(cell, is_reverse, time_major))
            for i in range(1, num_layers):
                # Deeper layers consume the previous layer's hidden output.
                cell = GRUCell(hidden_size, hidden_size, weight_ih_attr,
                               weight_hh_attr, bias_ih_attr, bias_hh_attr)
                self.append(RNN(cell, is_reverse, time_major))
        elif direction == "bidirectional":
            cell_fw = GRUCell(input_size, hidden_size, weight_ih_attr,
                              weight_hh_attr, bias_ih_attr, bias_hh_attr)
            cell_bw = GRUCell(input_size, hidden_size, weight_ih_attr,
                              weight_hh_attr, bias_ih_attr, bias_hh_attr)
            self.append(BiRNN(cell_fw, cell_bw, time_major))
            for i in range(1, num_layers):
                # BiRNN concatenates forward and backward outputs along the
                # last axis, so deeper layers see 2 * hidden_size features.
                cell_fw = GRUCell(2 * hidden_size, hidden_size, weight_ih_attr,
                                  weight_hh_attr, bias_ih_attr, bias_hh_attr)
                cell_bw = GRUCell(2 * hidden_size, hidden_size, weight_ih_attr,
                                  weight_hh_attr, bias_ih_attr, bias_hh_attr)
                self.append(BiRNN(cell_fw, cell_bw, time_major))
        else:
            raise ValueError(
                "direction should be forward, backward or bidirectional, "
                "received direction = {}".format(direction))

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_directions = 2 if direction == "bidirectional" else 1
        self.time_major = time_major
        self.num_layers = num_layers
        # One state tensor (h) per layer.
        self.state_components = 1