# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from . import Layer
from ..layers import (
    concat,
    fill_constant,
    matmul,
    elementwise_mul,
    split,
)
import copy

__all__ = ['LSTMCell', 'GRUCell']


class LSTMCell(Layer):
    r"""
    LSTMCell implementation using basic operators.
    There are two LSTMCell versions; the default one is compatible with the CUDNN LSTM implementation.
    The algorithm can be described by the equations below.

        .. math::

            i_t &= sigmoid(W_{ix}x_{t} + W_{ih}h_{t-1} + bx_i + bh_i)

            f_t &= sigmoid(W_{fx}x_{t} + W_{fh}h_{t-1} + bx_f + bh_f)

            o_t &= sigmoid(W_{ox}x_{t} + W_{oh}h_{t-1} + bx_o + bh_o)

            \tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + bx_c + bh_c)

            c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c_t}

            h_t &= o_t \odot tanh(c_t)

    The other LSTMCell version is compatible with the BasicLSTMUnit used in static graph mode.
    The algorithm can be described by the equations below.

        .. math::

            i_t &= sigmoid(W_{ix}x_{t} + W_{ih}h_{t-1} + b_i)

            f_t &= sigmoid(W_{fx}x_{t} + W_{fh}h_{t-1} + b_f + forget\_bias)

            o_t &= sigmoid(W_{ox}x_{t} + W_{oh}h_{t-1} + b_o)

            \tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + b_c)

            c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c_t}

            h_t &= o_t \odot tanh(c_t)

    Args:
        hidden_size (integer): The hidden size used in the Cell.
        input_size (integer): The input size used in the Cell.
        param_attr(ParamAttr|None): The parameter attribute for the learnable
            weight matrix. Note:
            If it is set to None or one attribute of ParamAttr, LSTMCell will
            create ParamAttr as param_attr. If the Initializer of the param_attr
            is not set, the parameter is initialized with Xavier. Default: None.
        bias_attr (ParamAttr|None): The parameter attribute for the bias
            of LSTMCell.
            If it is set to None or one attribute of ParamAttr, LSTMCell will
            create ParamAttr as bias_attr. If the Initializer of the bias_attr
            is not set, the bias is initialized as zero. Default: None.
        gate_activation (function|None): The activation function for gates (actGate).
                                  Default: 'paddle.nn.functional.sigmoid'
        activation (function|None): The activation function for cells (actNode).
                             Default: 'paddle.tanh'
        forget_bias(float|1.0): forget bias used when computing the forget gate. This
            is not used in the default LSTMCell implementation (CUDNN compatible).
        use_cudnn_impl(bool|True): whether to use the CUDNN compatible LSTMCell
        dtype(string): data type used in this cell

    Returns:
        None

    Examples:

        .. code-block:: python

            from paddle import fluid
            import paddle.fluid.core as core
            from paddle.fluid.dygraph import LSTMCell
            import numpy as np
            batch_size = 64
            input_size = 128
            hidden_size = 256
            step_input_np = np.random.uniform(-0.1, 0.1, (
                batch_size, input_size)).astype('float64')
            pre_hidden_np = np.random.uniform(-0.1, 0.1, (
                batch_size, hidden_size)).astype('float64')
            pre_cell_np = np.random.uniform(-0.1, 0.1, (
                batch_size, hidden_size)).astype('float64')
            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(0)
            else:
                place = core.CPUPlace()
            with fluid.dygraph.guard(place):
                cudnn_lstm = LSTMCell(hidden_size, input_size)
                step_input_var = fluid.dygraph.to_variable(step_input_np)
                pre_hidden_var = fluid.dygraph.to_variable(pre_hidden_np)
                pre_cell_var = fluid.dygraph.to_variable(pre_cell_np)
                new_hidden, new_cell = cudnn_lstm(step_input_var, pre_hidden_var, pre_cell_var)

    """

    def __init__(
        self,
        hidden_size,
        input_size,
        param_attr=None,
        bias_attr=None,
        gate_activation=None,
        activation=None,
        forget_bias=1.0,
        use_cudnn_impl=True,
        dtype='float64',
    ):
        super().__init__(dtype)

        self._hidden_size = hidden_size
        self._input_size = input_size
        self._param_attr = param_attr
        self._bias_attr = bias_attr
        self._dtype = dtype
        self._gate_activation = gate_activation or paddle.nn.functional.sigmoid
        self._activation = activation or paddle.tanh
        self._use_cudnn_impl = use_cudnn_impl

        if self._use_cudnn_impl:

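            # When a named attribute is supplied, derive distinct names for the
            # input-hidden and hidden-hidden parameters so they do not collide.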
            if (
                self._param_attr is not None
                and self._param_attr.name is not None
            ):
                weight_ih_param_attr = copy.deepcopy(self._param_attr)
                weight_hh_param_attr = copy.deepcopy(self._param_attr)
                weight_ih_param_attr.name += "_weight_ih"
                weight_hh_param_attr.name += "_weight_hh"
            else:
                weight_ih_param_attr = self._param_attr
                weight_hh_param_attr = self._param_attr

            if self._bias_attr is not None and self._bias_attr.name is not None:
                bias_ih_param_attr = copy.deepcopy(self._bias_attr)
                bias_hh_param_attr = copy.deepcopy(self._bias_attr)
                bias_ih_param_attr.name += "_bias_ih"
                bias_hh_param_attr.name += "_bias_hh"
            else:
                bias_ih_param_attr = self._bias_attr
                bias_hh_param_attr = self._bias_attr

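            # Parameters follow the CUDNN-compatible layout: the four gates
            # (input, forget, cell, output) are stacked along the first axis.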
            self._weight_ih = self.create_parameter(
                attr=weight_ih_param_attr,
                shape=[4 * self._hidden_size, self._input_size],
                dtype=self._dtype,
            )

            self._weight_hh = self.create_parameter(
                attr=weight_hh_param_attr,
                shape=[4 * self._hidden_size, self._hidden_size],
                dtype=self._dtype,
            )

            self._bias_ih = self.create_parameter(
                attr=bias_ih_param_attr,
                shape=[4 * self._hidden_size],
                dtype=self._dtype,
                is_bias=True,
            )
            self._bias_hh = self.create_parameter(
                attr=bias_hh_param_attr,
                shape=[4 * self._hidden_size],
                dtype=self._dtype,
                is_bias=True,
            )

        else:

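            # BasicLSTMUnit-compatible parameters: a single weight over the
            # concatenated [input, hidden] and a constant forget-gate bias.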
            self._forget_bias = fill_constant(
                [1], dtype=dtype, value=forget_bias
            )
            self._forget_bias.stop_gradient = False

            self._weight = self.create_parameter(
                attr=self._param_attr,
                shape=[
                    self._input_size + self._hidden_size,
                    4 * self._hidden_size,
                ],
                dtype=dtype,
            )

            self._bias = self.create_parameter(
                attr=self._bias_attr,
                shape=[4 * self._hidden_size],
                dtype=dtype,
                is_bias=True,
            )

    def forward(self, input, pre_hidden, pre_cell):

        if self._use_cudnn_impl:
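            # Gate pre-activations from the current input and the previous
            # hidden state, each with its own bias term.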
            igates = matmul(input, y=self._weight_ih, transpose_y=True)
            igates = paddle.add(igates, self._bias_ih)
            hgates = matmul(pre_hidden, self._weight_hh, transpose_y=True)
            hgates = paddle.add(hgates, self._bias_hh)

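            # Split the stacked pre-activations into the four gate chunks.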
            chunked_igates = split(igates, num_or_sections=4, dim=1)
            chunked_hgates = split(hgates, num_or_sections=4, dim=1)

            ingate = paddle.add(chunked_igates[0], chunked_hgates[0])
            ingate = self._gate_activation(ingate)

            forgetgate = paddle.add(chunked_igates[1], chunked_hgates[1])
            forgetgate = self._gate_activation(forgetgate)

            cellgate = paddle.add(chunked_igates[2], chunked_hgates[2])
            cellgate = self._activation(cellgate)

            outgate = paddle.add(chunked_igates[3], chunked_hgates[3])
            outgate = self._gate_activation(outgate)

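            # new_cell = f * c_prev + i * c_tilde; new_hidden = o * tanh(new_cell)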
            new_cell = (forgetgate * pre_cell) + (ingate * cellgate)
            new_hidden = outgate * self._activation(new_cell)

        else:

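            # BasicLSTMUnit-compatible path: one projection of the concatenated
            # [input, pre_hidden], split into the i, j (candidate), f, o gates.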
            concat_input_hidden = concat([input, pre_hidden], 1)
            gate_input = matmul(x=concat_input_hidden, y=self._weight)

            gate_input = paddle.add(gate_input, self._bias)
            i, j, f, o = split(gate_input, num_or_sections=4, dim=-1)
            new_cell = paddle.add(
                paddle.multiply(
                    pre_cell,
                    self._gate_activation(paddle.add(f, self._forget_bias)),
                ),
                paddle.multiply(
                    paddle.nn.functional.sigmoid(i), paddle.tanh(j)
                ),
            )
            new_hidden = self._activation(new_cell) * self._gate_activation(o)

        return new_hidden, new_cell


class GRUCell(Layer):
    r"""
    GRU implementation using basic operators.
    There are two GRUCell versions; the default one is compatible with the CUDNN GRU implementation.
    The algorithm can be described by the equations below.

        .. math::

            u_t & = sigmoid(W_{ux} x_{t} + b_{ux} + W_{uh} h_{t-1} + b_{uh})

            r_t & = sigmoid(W_{rx} x_{t} + b_{rx} + W_{rh} h_{t-1} + b_{rh})

            \tilde{h_{t}} & = tanh(W_{cx} x_{t} + b_{cx} + r_t \odot (W_{ch} h_{t-1} + b_{ch}))

            h_t & = u_t h_{t-1} + (1-u_t) \tilde{h_{t}}

    The other GRUCell version is compatible with the BasicGRUUnit used in static graph mode.
    The algorithm can be described by the equations below.

        .. math::

            u_t & = sigmoid(W_{ux} x_{t} + W_{uh} h_{t-1} + b_u)

            r_t & = sigmoid(W_{rx} x_{t} + W_{rh} h_{t-1} + b_r)

            \tilde{h_{t}} & = tanh(W_{cx} x_{t} + W_{ch} (r_t \odot h_{t-1}) + b_m)

            h_t & = u_t h_{t-1} + (1-u_t) \tilde{h_{t}}

    Args:
        hidden_size (integer): The hidden size used in the Cell.
        input_size (integer): The input size used in the Cell.
        param_attr(ParamAttr|None): The parameter attribute for the learnable
            weight matrix. Note:
            If it is set to None or one attribute of ParamAttr, GRUCell will
            create ParamAttr as param_attr. If the Initializer of the param_attr
            is not set, the parameter is initialized with Xavier. Default: None.
        bias_attr (ParamAttr|None): The parameter attribute for the bias
            of GRUCell.
            If it is set to None or one attribute of ParamAttr, GRUCell will
            create ParamAttr as bias_attr. If the Initializer of the bias_attr
            is not set, the bias is initialized as zero. Default: None.
        gate_activation (function|None): The activation function for gates (actGate).
                                  Default: 'paddle.nn.functional.sigmoid'
        activation (function|None): The activation function for cell (actNode).
                             Default: 'paddle.tanh'
        use_cudnn_impl(bool|True): whether to use the CUDNN compatible GRUCell
        dtype(string): data type used in this cell

    Returns:
        None

    Examples:

        .. code-block:: python

            from paddle import fluid
            import paddle.fluid.core as core
            from paddle.fluid.dygraph import GRUCell
            import numpy as np
            batch_size = 64
            input_size = 128
            hidden_size = 256
            step_input_np = np.random.uniform(-0.1, 0.1, (
                batch_size, input_size)).astype('float64')
            pre_hidden_np = np.random.uniform(-0.1, 0.1, (
                batch_size, hidden_size)).astype('float64')
            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(0)
            else:
                place = core.CPUPlace()
            with fluid.dygraph.guard(place):
                cudnn_gru = GRUCell(hidden_size, input_size)
                step_input_var = fluid.dygraph.to_variable(step_input_np)
                pre_hidden_var = fluid.dygraph.to_variable(pre_hidden_np)
                new_hidden = cudnn_gru(step_input_var, pre_hidden_var)

    """

    def __init__(
        self,
        hidden_size,
        input_size,
        param_attr=None,
        bias_attr=None,
        gate_activation=None,
        activation=None,
        use_cudnn_impl=True,
        dtype='float64',
    ):
        super().__init__()

        self._hidden_size = hidden_size
        self._input_size = input_size
        self._param_attr = param_attr
        self._bias_attr = bias_attr
        self._dtype = dtype
        self._gate_activation = gate_activation or paddle.nn.functional.sigmoid
        self._activation = activation or paddle.tanh
        self._use_cudnn_impl = use_cudnn_impl

        if self._use_cudnn_impl:

            if (
                self._param_attr is not None
                and self._param_attr.name is not None
            ):
                weight_ih_param_attr = copy.deepcopy(self._param_attr)
                weight_hh_param_attr = copy.deepcopy(self._param_attr)
                weight_ih_param_attr.name += "_weight_ih"
                weight_hh_param_attr.name += "_weight_hh"
            else:
                weight_ih_param_attr = self._param_attr
                weight_hh_param_attr = self._param_attr

            if self._bias_attr is not None and self._bias_attr.name is not None:
                bias_ih_param_attr = copy.deepcopy(self._bias_attr)
                bias_hh_param_attr = copy.deepcopy(self._bias_attr)
                bias_ih_param_attr.name += "_bias_ih"
                bias_hh_param_attr.name += "_bias_hh"
            else:
                bias_ih_param_attr = self._bias_attr
                bias_hh_param_attr = self._bias_attr

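            # Parameters follow the CUDNN-compatible layout: the three gates
            # (reset, update, candidate) are stacked along the first axis.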
            self._weight_ih = self.create_parameter(
                attr=weight_ih_param_attr,
                shape=[3 * self._hidden_size, self._input_size],
                dtype=self._dtype,
            )

            self._weight_hh = self.create_parameter(
                attr=weight_hh_param_attr,
                shape=[3 * self._hidden_size, self._hidden_size],
                dtype=self._dtype,
            )

            self._bias_ih = self.create_parameter(
                attr=bias_ih_param_attr,
                shape=[3 * self._hidden_size],
                dtype=self._dtype,
                is_bias=True,
            )
            self._bias_hh = self.create_parameter(
                attr=bias_hh_param_attr,
                shape=[3 * self._hidden_size],
                dtype=self._dtype,
                is_bias=True,
            )

        else:

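            # BasicGRUUnit-compatible parameters: separate gate (reset/update)
            # and candidate weights over the concatenated [input, hidden].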
            if (
                self._param_attr is not None
                and self._param_attr.name is not None
            ):
                gate_weight_param_attr = copy.deepcopy(self._param_attr)
                candidate_weight_param_attr = copy.deepcopy(self._param_attr)
                gate_weight_param_attr.name += "_gate_weight"
                candidate_weight_param_attr.name += "_candidate_weight"
            else:
                gate_weight_param_attr = self._param_attr
                candidate_weight_param_attr = self._param_attr

            if self._bias_attr is not None and self._bias_attr.name is not None:
                gate_bias_param_attr = copy.deepcopy(self._bias_attr)
                candidate_bias_param_attr = copy.deepcopy(self._bias_attr)
                gate_bias_param_attr.name += "_gate_bias"
                candidate_bias_param_attr.name += "_candidate_bias"
            else:
                gate_bias_param_attr = self._bias_attr
                candidate_bias_param_attr = self._bias_attr

            self._gate_weight = self.create_parameter(
                attr=gate_weight_param_attr,
                shape=[
                    self._input_size + self._hidden_size,
                    2 * self._hidden_size,
                ],
                dtype=dtype,
            )

            self._candidate_weight = self.create_parameter(
                attr=candidate_weight_param_attr,
                shape=[self._input_size + self._hidden_size, self._hidden_size],
                dtype=dtype,
            )

            self._gate_bias = self.create_parameter(
                attr=gate_bias_param_attr,
                shape=[2 * self._hidden_size],
                dtype=dtype,
                is_bias=True,
            )
            self._candidate_bias = self.create_parameter(
                attr=candidate_bias_param_attr,
                shape=[self._hidden_size],
                dtype=dtype,
                is_bias=True,
            )

    def forward(self, input, pre_hidden):

        if self._use_cudnn_impl:

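            # Gate pre-activations from the current input and the previous
            # hidden state; the chunks are (reset, update, candidate).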
            igates = matmul(input, y=self._weight_ih, transpose_y=True)
            igates = paddle.add(igates, self._bias_ih)
            hgates = matmul(pre_hidden, self._weight_hh, transpose_y=True)
            hgates = paddle.add(hgates, self._bias_hh)

            chunked_igates = split(igates, num_or_sections=3, dim=1)
            chunked_hgates = split(hgates, num_or_sections=3, dim=1)

            reset_gate = paddle.add(chunked_igates[0], chunked_hgates[0])
            reset_gate = self._gate_activation(reset_gate)

            input_gate = paddle.add(chunked_igates[1], chunked_hgates[1])
            input_gate = self._gate_activation(input_gate)

            _temp = reset_gate * chunked_hgates[2]
            new_gate = paddle.add(chunked_igates[2], _temp)
            new_gate = self._activation(new_gate)

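            # Equivalent to u * pre_hidden + (1 - u) * new_gate, with input_gate
            # acting as the update gate u.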
            new_hidden = (pre_hidden - new_gate) * input_gate + new_gate

        else:

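            # BasicGRUUnit-compatible path: reset/update gates come from the
            # concatenated [input, pre_hidden]; the candidate uses r * pre_hidden.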
            concat_input_hidden = concat([input, pre_hidden], 1)

            gate_input = matmul(x=concat_input_hidden, y=self._gate_weight)

            gate_input = paddle.add(gate_input, self._gate_bias)
            gate_input = self._gate_activation(gate_input)
            r, u = split(gate_input, num_or_sections=2, dim=1)

            r_hidden = r * pre_hidden

            candidate = matmul(
                concat([input, r_hidden], 1), self._candidate_weight
            )
            candidate = paddle.add(candidate, self._candidate_bias)

            c = self._activation(candidate)
            new_hidden = u * pre_hidden + (1 - u) * c

        return new_hidden