# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import Layer
from ..layers import sigmoid, tanh, concat, fill_constant, matmul, elementwise_add, elementwise_mul, split
import copy

__all__ = ['LSTMCell', 'GRUCell']


class LSTMCell(Layer):
    r"""
    LSTMCell implementation using basic operators.
    There are two LSTMCell versions; the default one is compatible with the CUDNN LSTM implementation.
    The algorithm can be described by the equations below.

        .. math::

            i_t &= sigmoid(W_{ix}x_{t} + W_{ih}h_{t-1} + bx_i + bh_i)

            f_t &= sigmoid(W_{fx}x_{t} + W_{fh}h_{t-1} + bx_f + bh_f)

            o_t &= sigmoid(W_{ox}x_{t} + W_{oh}h_{t-1} + bx_o + bh_o)

            \tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + bx_c + bh_c)

            c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c_t}

            h_t &= o_t \odot tanh(c_t)

    The other LSTMCell version is compatible with the BasicLSTMUnit used in static graph.
    The algorithm can be described by the equations below.

        .. math::

            i_t &= sigmoid(W_{ix}x_{t} + W_{ih}h_{t-1} + b_i)

            f_t &= sigmoid(W_{fx}x_{t} + W_{fh}h_{t-1} + b_f + forget\_bias)

            o_t &= sigmoid(W_{ox}x_{t} + W_{oh}h_{t-1} + b_o)

            \tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + b_c)

            c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c_t}

            h_t &= o_t \odot tanh(c_t)

    Args:
        hidden_size (integer): The hidden size used in the Cell.
        input_size (integer): The input size used in the Cell.
        param_attr(ParamAttr|None): The parameter attribute for the learnable
            weight matrix.
            If it is set to None or one attribute of ParamAttr, LSTMCell will
            create ParamAttr as param_attr. If the Initializer of the param_attr
            is not set, the parameter is initialized with Xavier. Default: None.
        bias_attr (ParamAttr|None): The parameter attribute for the bias
            of LSTMCell.
            If it is set to None or one attribute of ParamAttr, LSTMCell will
            create ParamAttr as bias_attr. If the Initializer of the bias_attr
            is not set, the bias is initialized as zero. Default: None.
        gate_activation (function|None): The activation function for gates (actGate).
            Default: 'fluid.layers.sigmoid'.
        activation (function|None): The activation function for cells (actNode).
            Default: 'fluid.layers.tanh'.
        forget_bias(float|1.0): forget bias used when computing the forget gate.
            It is not used in the default (CUDNN compatible) LSTMCell
            implementation. Default: 1.0.
        use_cudnn_impl(bool|True): whether to use the CUDNN compatible LSTMCell.
            Default: True.
        dtype(string): data type used in this cell. Default: 'float64'.

    Returns:
        None

    Examples:

        .. code-block:: python

            from paddle import fluid
            import paddle.fluid.core as core
            from paddle.fluid.dygraph import LSTMCell
            import numpy as np
            batch_size = 64
            input_size = 128
            hidden_size = 256
            step_input_np = np.random.uniform(-0.1, 0.1, (
                batch_size, input_size)).astype('float64')
            pre_hidden_np = np.random.uniform(-0.1, 0.1, (
                batch_size, hidden_size)).astype('float64')
            pre_cell_np = np.random.uniform(-0.1, 0.1, (
                batch_size, hidden_size)).astype('float64')
            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(0)
            else:
                place = core.CPUPlace()
            with fluid.dygraph.guard(place):
                cudnn_lstm = LSTMCell(hidden_size, input_size)
                step_input_var = fluid.dygraph.to_variable(step_input_np)
                pre_hidden_var = fluid.dygraph.to_variable(pre_hidden_np)
                pre_cell_var = fluid.dygraph.to_variable(pre_cell_np)
                new_hidden, new_cell = cudnn_lstm(step_input_var, pre_hidden_var, pre_cell_var) 
    """

    def __init__(self,
                 hidden_size,
                 input_size,
                 param_attr=None,
                 bias_attr=None,
                 gate_activation=None,
                 activation=None,
                 forget_bias=1.0,
                 use_cudnn_impl=True,
                 dtype='float64'):
        super(LSTMCell, self).__init__(dtype)

        self._hidden_size = hidden_size
        self._input_size = input_size
        self._param_attr = param_attr
        self._bias_attr = bias_attr
        self._dtype = dtype
        self._gate_activation = gate_activation or sigmoid
        self._activation = activation or tanh
        self._use_cudnn_impl = use_cudnn_impl

        if self._use_cudnn_impl:
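            # CUDNN-compatible path: separate input-to-hidden and hidden-to-hidden
            # weights and biases, with the four LSTM gates packed along the leading
            # dimension (hence the 4 * hidden_size shapes below). User-supplied
            # parameter names are suffixed so the ih/hh parameters do not collide.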

            if self._param_attr is not None and self._param_attr.name is not None:
                weight_ih_param_attr = copy.deepcopy(self._param_attr)
                weight_hh_param_attr = copy.deepcopy(self._param_attr)
                weight_ih_param_attr.name += "_weight_ih"
                weight_hh_param_attr.name += "_weight_hh"
            else:
                weight_ih_param_attr = self._param_attr
                weight_hh_param_attr = self._param_attr

            if self._bias_attr is not None and self._bias_attr.name is not None:
                bias_ih_param_attr = copy.deepcopy(self._bias_attr)
                bias_hh_param_attr = copy.deepcopy(self._bias_attr)
                bias_ih_param_attr.name += "_bias_ih"
                bias_hh_param_attr.name += "_bias_hh"
            else:
                bias_ih_param_attr = self._bias_attr
                bias_hh_param_attr = self._bias_attr

            self._weight_ih = self.create_parameter(
                attr=weight_ih_param_attr,
                shape=[4 * self._hidden_size, self._input_size],
                dtype=self._dtype)

            self._weight_hh = self.create_parameter(
                attr=weight_hh_param_attr,
                shape=[4 * self._hidden_size, self._hidden_size],
                dtype=self._dtype)

            self._bias_ih = self.create_parameter(
                attr=bias_ih_param_attr,
                shape=[4 * self._hidden_size],
                dtype=self._dtype,
                is_bias=True)
            self._bias_hh = self.create_parameter(
                attr=bias_hh_param_attr,
                shape=[4 * self._hidden_size],
                dtype=self._dtype,
                is_bias=True)

        else:
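            # BasicLSTMUnit-compatible path: a single concatenated weight of shape
            # [input_size + hidden_size, 4 * hidden_size], one packed bias, and a
            # constant forget-gate bias that is added before the gate activation.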

            self._forget_bias = fill_constant(
                [1], dtype=dtype, value=forget_bias)
            self._forget_bias.stop_gradient = False

            self._weight = self.create_parameter(
                attr=self._param_attr,
                shape=[
                    self._input_size + self._hidden_size, 4 * self._hidden_size
                ],
                dtype=dtype)

            self._bias = self.create_parameter(
                attr=self._bias_attr,
                shape=[4 * self._hidden_size],
                dtype=dtype,
                is_bias=True)

    def forward(self, input, pre_hidden, pre_cell):

        if self._use_cudnn_impl:
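            # Project the input and the previous hidden state through the packed
            # gate parameters, add the per-gate biases, then split the result into
            # the four gates in (input, forget, cell, output) order.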
            igates = matmul(input, y=self._weight_ih, transpose_y=True)
            igates = elementwise_add(igates, self._bias_ih)
            hgates = matmul(pre_hidden, self._weight_hh, transpose_y=True)
            hgates = elementwise_add(hgates, self._bias_hh)

            chunked_igates = split(igates, num_or_sections=4, dim=1)
            chunked_hgates = split(hgates, num_or_sections=4, dim=1)

            ingate = elementwise_add(chunked_igates[0], chunked_hgates[0])
            ingate = self._gate_activation(ingate)

            forgetgate = elementwise_add(chunked_igates[1], chunked_hgates[1])
            forgetgate = self._gate_activation(forgetgate)

            cellgate = elementwise_add(chunked_igates[2], chunked_hgates[2])
            cellgate = self._activation(cellgate)

            outgate = elementwise_add(chunked_igates[3], chunked_hgates[3])
            outgate = self._gate_activation(outgate)

            new_cell = (forgetgate * pre_cell) + (ingate * cellgate)
            new_hidden = outgate * self._activation(new_cell)

        else:
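            # BasicLSTMUnit-compatible path: concatenate [input, pre_hidden], apply
            # one matmul, split the result into the i, j (cell candidate), f, o
            # chunks, and add forget_bias to f before the forget-gate activation.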

            concat_input_hidden = concat([input, pre_hidden], 1)
            gate_input = matmul(x=concat_input_hidden, y=self._weight)

            gate_input = elementwise_add(gate_input, self._bias)
            i, j, f, o = split(gate_input, num_or_sections=4, dim=-1)
            new_cell = elementwise_add(
                elementwise_mul(pre_cell,
                                self._gate_activation(
                                    elementwise_add(f, self._forget_bias))),
                elementwise_mul(sigmoid(i), tanh(j)))
            new_hidden = self._activation(new_cell) * self._gate_activation(o)

        return new_hidden, new_cell


class GRUCell(Layer):
    r"""
    GRU implementation using basic operators.
    There are two GRUCell versions; the default one is compatible with the CUDNN GRU implementation.
    The algorithm can be described by the equations below.

        .. math::

            u_t & = sigmoid(W_{ux} x_{t} + b_{ux} + W_{uh} h_{t-1} + b_{uh})

            r_t & = sigmoid(W_{rx} x_{t} + b_{rx} + W_{rh} h_{t-1} + b_{rh})

            \tilde{h_{t}} & = tanh(W_{cx} x_{t} + b_{cx} + r_t \odot (W_{ch} h_{t-1} + b_{ch}))

            h_t & = u_t h_{t-1} + (1-u_t) \tilde{h_{t}}

    The other GRUCell version is compatible with the BasicGRUUnit used in static graph.
    The algorithm can be described by the equations below.

        .. math::

            u_t & = sigmoid(W_{ux} x_{t} + W_{uh} h_{t-1} + b_u)

            r_t & = sigmoid(W_{rx} x_{t} + W_{rh} h_{t-1} + b_r)

            \tilde{h_{t}} & = tanh(W_{cx} x_{t} + W_{ch} (r_t \odot h_{t-1}) + b_m)

            h_t & = u_t h_{t-1} + (1-u_t) \tilde{h_{t}}

    Args:
        hidden_size (integer): The hidden size used in the Cell.
        input_size (integer): The input size used in the Cell.
        param_attr(ParamAttr|None): The parameter attribute for the learnable
            weight matrix.
            If it is set to None or one attribute of ParamAttr, GRUCell will
            create ParamAttr as param_attr. If the Initializer of the param_attr
            is not set, the parameter is initialized with Xavier. Default: None.
        bias_attr (ParamAttr|None): The parameter attribute for the bias
            of GRUCell.
            If it is set to None or one attribute of ParamAttr, GRUCell will
            create ParamAttr as bias_attr. If the Initializer of the bias_attr
            is not set, the bias is initialized as zero. Default: None.
        gate_activation (function|None): The activation function for gates (actGate).
            Default: 'fluid.layers.sigmoid'.
        activation (function|None): The activation function for the cell (actNode).
            Default: 'fluid.layers.tanh'.
        use_cudnn_impl(bool|True): whether to use the CUDNN compatible GRUCell.
            Default: True.
        dtype(string): data type used in this cell. Default: 'float64'.

    Returns:
        None

    Examples:

        .. code-block:: python

            from paddle import fluid
            import paddle.fluid.core as core
            from paddle.fluid.dygraph import GRUCell
            import numpy as np
            batch_size = 64
            input_size = 128
            hidden_size = 256
            step_input_np = np.random.uniform(-0.1, 0.1, (
                batch_size, input_size)).astype('float64')
            pre_hidden_np = np.random.uniform(-0.1, 0.1, (
                batch_size, hidden_size)).astype('float64')
            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(0)
            else:
                place = core.CPUPlace()
            with fluid.dygraph.guard(place):
                cudnn_gru = GRUCell(hidden_size, input_size)
                step_input_var = fluid.dygraph.to_variable(step_input_np)
                pre_hidden_var = fluid.dygraph.to_variable(pre_hidden_np)
                new_hidden = cudnn_gru(step_input_var, pre_hidden_var)
    """

    def __init__(self,
                 hidden_size,
                 input_size,
                 param_attr=None,
                 bias_attr=None,
                 gate_activation=None,
                 activation=None,
                 use_cudnn_impl=True,
                 dtype='float64'):
        super(GRUCell, self).__init__(dtype)

        self._hidden_size = hidden_size
        self._input_size = input_size
        self._param_attr = param_attr
        self._bias_attr = bias_attr
        self._dtype = dtype
        self._gate_activation = gate_activation or sigmoid
        self._activation = activation or tanh
        self._use_cudnn_impl = use_cudnn_impl

        if self._use_cudnn_impl:
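            # CUDNN-compatible path: separate input-to-hidden and hidden-to-hidden
            # weights and biases, with the three GRU gates packed along the leading
            # dimension (hence the 3 * hidden_size shapes below).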

            if self._param_attr is not None and self._param_attr.name is not None:
                weight_ih_param_attr = copy.deepcopy(self._param_attr)
                weight_hh_param_attr = copy.deepcopy(self._param_attr)
                weight_ih_param_attr.name += "_weight_ih"
                weight_hh_param_attr.name += "_weight_hh"
            else:
                weight_ih_param_attr = self._param_attr
                weight_hh_param_attr = self._param_attr

            if self._bias_attr is not None and self._bias_attr.name is not None:
                bias_ih_param_attr = copy.deepcopy(self._bias_attr)
                bias_hh_param_attr = copy.deepcopy(self._bias_attr)
                bias_ih_param_attr.name += "_bias_ih"
                bias_hh_param_attr.name += "_bias_hh"
            else:
                bias_ih_param_attr = self._bias_attr
                bias_hh_param_attr = self._bias_attr

            self._weight_ih = self.create_parameter(
                attr=weight_ih_param_attr,
                shape=[3 * self._hidden_size, self._input_size],
                dtype=self._dtype)

            self._weight_hh = self.create_parameter(
                attr=weight_hh_param_attr,
                shape=[3 * self._hidden_size, self._hidden_size],
                dtype=self._dtype)

            self._bias_ih = self.create_parameter(
                attr=bias_ih_param_attr,
                shape=[3 * self._hidden_size],
                dtype=self._dtype,
                is_bias=True)
            self._bias_hh = self.create_parameter(
                attr=bias_hh_param_attr,
                shape=[3 * self._hidden_size],
                dtype=self._dtype,
                is_bias=True)

        else:
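            # BasicGRUUnit-compatible path: one concatenated weight/bias pair for the
            # update and reset gates and a separate weight/bias pair for the
            # candidate hidden state.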

            if self._param_attr is not None and self._param_attr.name is not None:
                gate_weight_param_attr = copy.deepcopy(self._param_attr)
                candidate_weight_param_attr = copy.deepcopy(self._param_attr)
                gate_weight_param_attr.name += "_gate_weight"
                candidate_weight_param_attr.name += "_candidate_weight"
            else:
                gate_weight_param_attr = self._param_attr
                candidate_weight_param_attr = self._param_attr

            if self._bias_attr is not None and self._bias_attr.name is not None:
                gate_bias_param_attr = copy.deepcopy(self._bias_attr)
                candidate_bias_param_attr = copy.deepcopy(self._bias_attr)
                gate_bias_param_attr.name += "_gate_bias"
                candidate_bias_param_attr.name += "_candidate_bias"
            else:
                gate_bias_param_attr = self._bias_attr
                candidate_bias_param_attr = self._bias_attr

            self._gate_weight = self.create_parameter(
                attr=gate_weight_param_attr,
                shape=[
                    self._input_size + self._hidden_size, 2 * self._hidden_size
                ],
                dtype=dtype)

            self._candidate_weight = self.create_parameter(
                attr=candidate_weight_param_attr,
                shape=[
                    self._input_size + self._hidden_size, self._hidden_size
                ],
                dtype=dtype)

            self._gate_bias = self.create_parameter(
                attr=gate_bias_param_attr,
                shape=[2 * self._hidden_size],
                dtype=dtype,
                is_bias=True)
            self._candidate_bias = self.create_parameter(
                attr=candidate_bias_param_attr,
                shape=[self._hidden_size],
                dtype=dtype,
                is_bias=True)

    def forward(self, input, pre_hidden):

        if self._use_cudnn_impl:
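            # Project the input and the previous hidden state through the packed gate
            # parameters, add the biases, then split into three chunks used for the
            # reset gate, the update (input) gate and the candidate ("new") state.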

            igates = matmul(input, y=self._weight_ih, transpose_y=True)
            igates = elementwise_add(igates, self._bias_ih)
            hgates = matmul(pre_hidden, self._weight_hh, transpose_y=True)
            hgates = elementwise_add(hgates, self._bias_hh)

            chunked_igates = split(igates, num_or_sections=3, dim=1)
            chunked_hgates = split(hgates, num_or_sections=3, dim=1)

            reset_gate = elementwise_add(chunked_igates[0], chunked_hgates[0])
            reset_gate = self._gate_activation(reset_gate)

            input_gate = elementwise_add(chunked_igates[1], chunked_hgates[1])
            input_gate = self._gate_activation(input_gate)

            _temp = reset_gate * chunked_hgates[2]
            new_gate = elementwise_add(chunked_igates[2], _temp)
            new_gate = self._activation(new_gate)

            new_hidden = (pre_hidden - new_gate) * input_gate + new_gate

        else:
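            # BasicGRUUnit-compatible path: compute the reset and update gates from
            # [input, pre_hidden], gate the previous hidden state with r, then form
            # the candidate state and interpolate with the update gate u.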

            concat_input_hidden = concat([input, pre_hidden], 1)

            gate_input = matmul(x=concat_input_hidden, y=self._gate_weight)

            gate_input = elementwise_add(gate_input, self._gate_bias)
            gate_input = self._gate_activation(gate_input)
            r, u = split(gate_input, num_or_sections=2, dim=1)

            r_hidden = r * pre_hidden

            candidate = matmul(
                concat([input, r_hidden], 1), self._candidate_weight)
            candidate = elementwise_add(candidate, self._candidate_bias)

            c = self._activation(candidate)
            new_hidden = u * pre_hidden + (1 - u) * c

        return new_hidden