From f4083010a756f0a9554fdfef26fea262e8978f91 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Thu, 27 Aug 2020 20:52:18 +0800 Subject: [PATCH] Add unified RNN APIs (#26588) * Add RNN related apis in paddl.nn test=develop * new rnn api, cell almost done * add new progresses in rnn APIs for 2.0 * refine rnn APIs and docstrings. * add unittets * disable gpu tests when paddle is not compiled with cuda support * remove unnecessary imports * fix docstring * add to no_sample wlist * backport to python2 to avoid yield from * add **kwargs, fix typos * update docstrings for birnn * rename argument for SimpleRNN and SimpleRNNCell, fix sample code * add default value for initial_states in fluid.layers.birnn Co-authored-by: guosheng --- python/paddle/fluid/layers/rnn.py | 288 +++- .../fluid/tests/unittests/CMakeLists.txt | 1 + .../fluid/tests/unittests/rnn/CMakeLists.txt | 6 + .../fluid/tests/unittests/rnn/__init__.py | 13 + .../fluid/tests/unittests/rnn/convert.py | 51 + .../fluid/tests/unittests/rnn/rnn_numpy.py | 516 +++++++ .../tests/unittests/rnn/test_rnn_cells.py | 166 ++ .../unittests/rnn/test_rnn_cells_static.py | 326 ++++ .../tests/unittests/rnn/test_rnn_nets.py | 269 ++++ .../unittests/rnn/test_rnn_nets_static.py | 470 ++++++ python/paddle/nn/__init__.py | 3 + python/paddle/nn/functional/__init__.py | 2 + python/paddle/nn/functional/rnn.py | 8 +- python/paddle/nn/layer/__init__.py | 2 + python/paddle/nn/layer/rnn.py | 1331 ++++++++++++++++- tools/wlist.json | 15 +- 16 files changed, 3391 insertions(+), 76 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/rnn/CMakeLists.txt create mode 100644 python/paddle/fluid/tests/unittests/rnn/__init__.py create mode 100644 python/paddle/fluid/tests/unittests/rnn/convert.py create mode 100644 python/paddle/fluid/tests/unittests/rnn/rnn_numpy.py create mode 100644 python/paddle/fluid/tests/unittests/rnn/test_rnn_cells.py create mode 100644 python/paddle/fluid/tests/unittests/rnn/test_rnn_cells_static.py create mode 100644 python/paddle/fluid/tests/unittests/rnn/test_rnn_nets.py create mode 100644 python/paddle/fluid/tests/unittests/rnn/test_rnn_nets_static.py diff --git a/python/paddle/fluid/layers/rnn.py b/python/paddle/fluid/layers/rnn.py index bc1368b562..fe8ed83923 100644 --- a/python/paddle/fluid/layers/rnn.py +++ b/python/paddle/fluid/layers/rnn.py @@ -38,6 +38,7 @@ __all__ = [ 'Decoder', 'BeamSearchDecoder', 'rnn', + 'birnn', 'dynamic_decode', 'DecodeHelper', 'TrainingHelper', @@ -438,61 +439,146 @@ def rnn(cell, is_reverse=False, **kwargs): """ - :api_attr: Static Graph - rnn creates a recurrent neural network specified by RNNCell `cell`, - which performs :code:`cell.call()` repeatedly until reaches to the maximum - length of `inputs`. - - Parameters: - cell(RNNCell): An instance of `RNNCell`. - inputs(Variable): A (possibly nested structure of) tensor variable[s]. - The shape of tensor should be `[batch_size, sequence_length, ...]` - for `time_major == False` or `[sequence_length, batch_size, ...]` - for `time_major == True`. It represents the inputs to be unrolled - in RNN. - initial_states(Variable, optional): A (possibly nested structure of) - tensor variable[s], representing the initial state for RNN. - If not provided, `cell.get_initial_states` would be used to produce - the initial state. Default None. - sequence_length(Variable, optional): A tensor with shape `[batch_size]`. 
- It stores real length of each instance, thus enables users to extract - the last valid state when past a batch element's sequence length for - correctness. If not provided, the paddings would be treated same as - non-padding inputs. Default None. - time_major(bool, optional): Indicate the data layout of Tensor included - in `input` and `output` tensors. If `False`, the data layout would - be batch major with shape `[batch_size, sequence_length, ...]`. If - `True`, the data layout would be time major with shape - `[sequence_length, batch_size, ...]`. Default: `False`. - is_reverse(bool, optional): Indicate whether to calculate in the reverse - order of input sequences. Default: `False`. - **kwargs: Additional keyword arguments. Arguments passed to `cell.call`. + which performs :code:`cell.call()` (for dygraph mode :code:`cell.forward`) + repeatedly until it reaches the maximum length of `inputs`. + + Arguments: + cell(RNNCellBase): An instance of `RNNCellBase`. + inputs(Tensor): the input sequences. + If time_major is True, the shape is + `[time_steps, batch_size, input_size]` + else the shape is `[batch_size, time_steps, input_size]`. + initial_states(Tensor|tuple|list, optional): the initial state of the + rnn cell. Tensor or a possibly nested structure of tensors. If not + provided, `cell.get_initial_states` would be called to produce + the initial state. Defaults to None. + sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64 + or int32. The valid lengths of input sequences. Defaults to None. + If `sequence_length` is not None, the inputs are treated as + padded sequences. In each input sequence, elements whose time step + index is not less than the valid length are treated as paddings. + time_major (bool): Whether the first dimension of the input means the + time steps. Defaults to False. + is_reverse (bool, optional): Indicate whether to calculate in the reverse + order of input sequences. Defaults to False. + **kwargs: Additional keyword arguments to pass to `forward` of the cell. Returns: - tuple: A tuple( :code:`(final_outputs, final_states)` ) including the final \ - outputs and states, both are Tensor or nested structure of Tensor. \ - `final_outputs` has the same structure and data types as \ - the returned `outputs` of :code:`cell.call` , and each Tenser in `final_outputs` \ - stacks all time steps' counterpart in `outputs` thus has shape `[batch_size, sequence_length, ...]` \ - for `time_major == False` or `[sequence_length, batch_size, ...]` for `time_major == True`. \ - `final_states` is the counterpart at last time step of initial states, \ - thus has the same structure with it and has tensors with same shapes \ - and data types. + (outputs, final_states) + outputs (Tensor|list|tuple): the output sequence. Tensor or nested + structure of Tensors. + If `time_major` is True, the shape of each tensor in outputs is + `[time_steps, batch_size, hidden_size]`, else + `[batch_size, time_steps, hidden_size]`. + final_states (Tensor|list|tuple): final states. A (possibly nested structure of) + tensor[s], representing the final state for RNN. It has the same + structure as the initial states. Each tensor in final states has the same + shape and dtype as the corresponding tensor in initial states. Examples: ..
code-block:: python - - import paddle.fluid as fluid - inputs = fluid.data(name="inputs", - shape=[-1, 32, 128], - dtype="float32") - cell = fluid.layers.GRUCell(hidden_size=128) - outputs = fluid.layers.rnn(cell=cell, inputs=inputs) + import paddle + paddle.disable_static() + + cell = paddle.nn.SimpleRNNCell(16, 32) + + inputs = paddle.rand((4, 23, 16)) + prev_h = paddle.randn((4, 32)) + outputs, final_states = paddle.nn.functional.rnn(cell, inputs, prev_h) + """ + if in_dygraph_mode(): + return _rnn_dynamic_graph(cell, inputs, initial_states, sequence_length, + time_major, is_reverse, **kwargs) + else: + return _rnn_static_graph(cell, inputs, initial_states, sequence_length, + time_major, is_reverse, **kwargs) + + +class ArrayWrapper(object): + def __init__(self, x): + self.array = [x] + + def append(self, x): + self.array.append(x) + return self + + +def _maybe_copy(state, new_state, step_mask): + """update rnn state or just pass the old state through""" + new_state = nn.elementwise_mul(new_state, step_mask, axis=0) \ + + nn.elementwise_mul(state, (1 - step_mask), axis=0) + return new_state + + +def _transpose_batch_time(x): + perm = [1, 0] + list(range(2, len(x.shape))) + return nn.transpose(x, perm) + + +def _rnn_dynamic_graph(cell, + inputs, + initial_states=None, + sequence_length=None, + time_major=False, + is_reverse=False, + **kwargs): + time_step_index = 0 if time_major else 1 + flat_inputs = flatten(inputs) + time_steps = flat_inputs[0].shape[time_step_index] + + if not time_major: + inputs = map_structure(_transpose_batch_time, inputs) + + if sequence_length is not None: + mask = sequence_lod.sequence_mask( + sequence_length, maxlen=time_steps, dtype=inputs.dtype) + mask = nn.transpose(mask, [1, 0]) + + if is_reverse: + inputs = map_structure(lambda x: tensor.reverse(x, axis=[0]), inputs) + mask = tensor.reverse(mask, axis=[0]) \ + if sequence_length is not None else None + + states = initial_states + outputs = [] + for i in range(time_steps): + step_inputs = map_structure(lambda x: x[i], inputs) + step_outputs, new_states = cell(step_inputs, states, **kwargs) + if sequence_length is not None: + new_states = map_structure( + partial( + _maybe_copy, step_mask=mask[i]), states, new_states) + states = new_states + outputs = map_structure(lambda x: ArrayWrapper(x), + step_outputs) if i == 0 else map_structure( + lambda x, x_array: x_array.append(x), + step_outputs, outputs) + + final_outputs = map_structure( + lambda x: nn.stack(x.array, axis=time_step_index), + outputs) + + if is_reverse: + final_outputs = map_structure( + lambda x: tensor.reverse(x, axis=time_step_index), + final_outputs) + + final_states = new_states + return final_outputs, final_states + + +def _rnn_static_graph(cell, + inputs, + initial_states=None, + sequence_length=None, + time_major=False, + is_reverse=False, + **kwargs): check_type(inputs, 'inputs', (Variable, list, tuple), 'rnn') if isinstance(inputs, (list, tuple)): for i, input_x in enumerate(inputs): @@ -500,30 +586,10 @@ def rnn(cell, ['float32', 'float64'], 'rnn') check_type(initial_states, 'initial_states', (Variable, list, tuple, type(None)), 'rnn') - if isinstance(initial_states, (list, tuple)): - states = map_structure(lambda x: x, initial_states)[0] - for i, state in enumerate(states): - if isinstance(state, (list, tuple)): - for j, state_j in enumerate(state): - check_variable_and_dtype(state_j, 'state_j[' + str(j) + ']', - ['float32', 'float64'], 'rnn') - else: - check_variable_and_dtype(state, 'states[' + str(i) + ']', - ['float32', 
'float64'], 'rnn') check_type(sequence_length, 'sequence_length', (Variable, type(None)), 'rnn') - def _maybe_copy(state, new_state, step_mask): - # TODO: use where_op - new_state = nn.elementwise_mul( - new_state, step_mask, axis=0) - nn.elementwise_mul( - state, (step_mask - 1), axis=0) - return new_state - - def _transpose_batch_time(x): - return nn.transpose(x, [1, 0] + list(range(2, len(x.shape)))) - def _switch_grad(x, stop=False): x.stop_gradient = stop return x @@ -582,6 +648,98 @@ def rnn(cell, return (final_outputs, final_states) +def birnn(cell_fw, + cell_bw, + inputs, + initial_states=None, + sequence_length=None, + time_major=False, + **kwargs): + """ + birnn creates a bidirectional recurrent neural network specified by + RNNCell `cell_fw` and `cell_bw`, which performs :code:`cell.call()` + (for dygraph mode :code:`cell.forward`) repeatedly until it reaches + the maximum length of `inputs` and then concatenates the outputs of both RNNs + along the last axis. + + Arguments: + cell_fw(RNNCellBase): An instance of `RNNCellBase`. + cell_bw(RNNCellBase): An instance of `RNNCellBase`. + inputs(Tensor): the input sequences. + If time_major is True, the shape is + `[time_steps, batch_size, input_size]` + else the shape is `[batch_size, time_steps, input_size]`. + initial_states(tuple, optional): A tuple of initial states of + `cell_fw` and `cell_bw`. + If not provided, `cell.get_initial_states` would be called to + produce initial state for each cell. Defaults to None. + sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64 + or int32. The valid lengths of input sequences. Defaults to None. + If `sequence_length` is not None, the inputs are treated as + padded sequences. In each input sequence, elements whose time step + index is not less than the valid length are treated as paddings. + time_major (bool): Whether the first dimension of the input means the + time steps. Defaults to False. + **kwargs: Additional keyword arguments to pass to `forward` of each cell. + + Returns: + (outputs, final_states) + outputs (Tensor): the outputs of the bidirectional RNN. It is the + concatenation of the outputs from the forward RNN and backward + RNN along the last axis. + If time major is True, the shape is `[time_steps, batch_size, size]`, + else the shape is `[batch_size, time_steps, size]`, where size is + `cell_fw.hidden_size + cell_bw.hidden_size`. + final_states (tuple): A tuple of the final states of the forward + cell and backward cell. + + Examples: + + ..
code-block:: python + + import paddle + paddle.disable_static() + + cell_fw = paddle.nn.LSTMCell(16, 32) + cell_bw = paddle.nn.LSTMCell(16, 32) + + inputs = paddle.rand((4, 23, 16)) + hf, cf = paddle.rand((4, 32)), paddle.rand((4, 32)) + hb, cb = paddle.rand((4, 32)), paddle.rand((4, 32)) + initial_states = ((hf, cf), (hb, cb)) + outputs, final_states = paddle.nn.functional.birnn( + cell_fw, cell_bw, inputs, initial_states) + + """ + if initial_states is None: + states_fw = cell_fw.get_initial_states( + batch_ref=inputs, batch_dim_idx=1 if time_major else 0) + states_bw = cell_bw.get_initial_states( + batch_ref=inputs, batch_dim_idx=1 if time_major else 0) + else: + states_fw, states_bw = initial_states + outputs_fw, states_fw = rnn(cell_fw, + inputs, + states_fw, + sequence_length, + time_major=time_major, + **kwargs) + + outputs_bw, states_bw = rnn(cell_bw, + inputs, + states_bw, + sequence_length, + time_major=time_major, + is_reverse=True, + **kwargs) + + outputs = map_structure(lambda x, y: tensor.concat([x, y], -1), outputs_fw, + outputs_bw) + + final_states = (states_fw, states_bw) + return outputs, final_states + + class Decoder(object): """ :api_attr: Static Graph diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 7e51ecf07e..6220bf62c7 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -542,6 +542,7 @@ endif() add_subdirectory(sequence) add_subdirectory(dygraph_to_static) +add_subdirectory(rnn) if (WITH_MKLDNN) add_subdirectory(mkldnn) diff --git a/python/paddle/fluid/tests/unittests/rnn/CMakeLists.txt b/python/paddle/fluid/tests/unittests/rnn/CMakeLists.txt new file mode 100644 index 0000000000..f71e04c09a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/rnn/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +foreach(TEST_OP ${TEST_OPS}) + py_test_modules(${TEST_OP} MODULES ${TEST_OP}) +endforeach(TEST_OP) diff --git a/python/paddle/fluid/tests/unittests/rnn/__init__.py b/python/paddle/fluid/tests/unittests/rnn/__init__.py new file mode 100644 index 0000000000..abf198b97e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/rnn/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/paddle/fluid/tests/unittests/rnn/convert.py b/python/paddle/fluid/tests/unittests/rnn/convert.py new file mode 100644 index 0000000000..02f10694a4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/rnn/convert.py @@ -0,0 +1,51 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np + + +def convert_params_for_cell(np_cell, paddle_cell): + state = np_cell.parameters + for k, v in paddle_cell.named_parameters(): + v.set_value(state[k]) + + +def convert_params_for_cell_static(np_cell, paddle_cell, place): + state = np_cell.parameters + for k, v in paddle_cell.named_parameters(): + scope = paddle.static.global_scope() + tensor = scope.find_var(v.name).get_tensor() + tensor.set(state[k], place) + + +def convert_params_for_net(np_net, paddle_net): + for np_layer, paddle_layer in zip(np_net, paddle_net): + if hasattr(np_layer, "cell"): + convert_params_for_cell(np_layer.cell, paddle_layer.cell) + else: + convert_params_for_cell(np_layer.cell_fw, paddle_layer.cell_fw) + convert_params_for_cell(np_layer.cell_bw, paddle_layer.cell_bw) + + +def convert_params_for_net_static(np_net, paddle_net, place): + for np_layer, paddle_layer in zip(np_net, paddle_net): + if hasattr(np_layer, "cell"): + convert_params_for_cell_static(np_layer.cell, paddle_layer.cell, + place) + else: + convert_params_for_cell_static(np_layer.cell_fw, + paddle_layer.cell_fw, place) + convert_params_for_cell_static(np_layer.cell_bw, + paddle_layer.cell_bw, place) diff --git a/python/paddle/fluid/tests/unittests/rnn/rnn_numpy.py b/python/paddle/fluid/tests/unittests/rnn/rnn_numpy.py new file mode 100644 index 0000000000..7e0b8374b9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/rnn/rnn_numpy.py @@ -0,0 +1,516 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import math + + +class LayerMixin(object): + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + +class LayerListMixin(LayerMixin): + def __init__(self, layers=None): + self._layers = list(layers) if layers else [] + + def append(self, layer): + self._layers.append(layer) + + def __iter__(self): + return iter(self._layers) + + +class SimpleRNNCell(LayerMixin): + def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"): + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + if nonlinearity == 'tanh': + self.nonlinearity = np.tanh + else: + self.nonlinearity = lambda x: np.maximum(x, 0.) 
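+        # Parameters are sampled from uniform(-1/sqrt(hidden_size),
+        # 1/sqrt(hidden_size)) with weight_ih laid out as
+        # (hidden_size, input_size); the dict keys below follow
+        # paddle.nn.SimpleRNNCell's parameter names so that
+        # convert_params_for_cell can copy them into the paddle cell by key.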
+ + self.parameters = dict() + std = 1.0 / math.sqrt(hidden_size) + self.weight_ih = np.random.uniform(-std, std, ( + hidden_size, input_size)).astype('float64') + self.weight_hh = np.random.uniform(-std, std, ( + hidden_size, hidden_size)).astype('float64') + self.parameters['weight_ih'] = self.weight_ih + self.parameters['weight_hh'] = self.weight_hh + if bias: + self.bias_ih = np.random.uniform(-std, std, + (hidden_size, )).astype('float64') + self.bias_hh = np.random.uniform(-std, std, + (hidden_size, )).astype('float64') + self.parameters['bias_ih'] = self.bias_ih + self.parameters['bias_hh'] = self.bias_hh + else: + self.bias_ih = None + self.bias_hh = None + + def init_state(self, inputs): + batch_size = inputs.shape[0] + return np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype) + + def forward(self, inputs, hx=None): + if hx is None: + hx = self.init_state(inputs) + pre_h = hx + i2h = np.matmul(inputs, self.weight_ih.T) + if self.bias_ih is not None: + i2h += self.bias_ih + h2h = np.matmul(pre_h, self.weight_hh.T) + if self.bias_hh is not None: + h2h += self.bias_hh + h = self.nonlinearity(i2h + h2h) + return h, h + + +class GRUCell(LayerMixin): + def __init__(self, input_size, hidden_size, bias=True): + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + self.parameters = dict() + std = 1.0 / math.sqrt(hidden_size) + self.weight_ih = np.random.uniform(-std, std, ( + 3 * hidden_size, input_size)).astype('float64') + self.weight_hh = np.random.uniform(-std, std, ( + 3 * hidden_size, hidden_size)).astype('float64') + self.parameters['weight_ih'] = self.weight_ih + self.parameters['weight_hh'] = self.weight_hh + if bias: + self.bias_ih = np.random.uniform(-std, std, ( + 3 * hidden_size)).astype('float64') + self.bias_hh = np.random.uniform(-std, std, ( + 3 * hidden_size)).astype('float64') + self.parameters['bias_ih'] = self.bias_ih + self.parameters['bias_hh'] = self.bias_hh + else: + self.bias_ih = None + self.bias_hh = None + + def init_state(self, inputs): + batch_size = inputs.shape[0] + return np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype) + + def forward(self, inputs, hx=None): + if hx is None: + hx = self.init_state(inputs) + pre_hidden = hx + x_gates = np.matmul(inputs, self.weight_ih.T) + if self.bias_ih is not None: + x_gates = x_gates + self.bias_ih + h_gates = np.matmul(pre_hidden, self.weight_hh.T) + if self.bias_hh is not None: + h_gates = h_gates + self.bias_hh + + x_r, x_z, x_c = np.split(x_gates, 3, 1) + h_r, h_z, h_c = np.split(h_gates, 3, 1) + + r = 1.0 / (1.0 + np.exp(-(x_r + h_r))) + z = 1.0 / (1.0 + np.exp(-(x_z + h_z))) + c = np.tanh(x_c + r * h_c) # apply reset gate after mm + h = (pre_hidden - c) * z + c + return h, h + + +class LSTMCell(LayerMixin): + def __init__(self, input_size, hidden_size, bias=True): + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + self.parameters = dict() + std = 1.0 / math.sqrt(hidden_size) + self.weight_ih = np.random.uniform(-std, std, ( + 4 * hidden_size, input_size)).astype('float64') + self.weight_hh = np.random.uniform(-std, std, ( + 4 * hidden_size, hidden_size)).astype('float64') + self.parameters['weight_ih'] = self.weight_ih + self.parameters['weight_hh'] = self.weight_hh + if bias: + self.bias_ih = np.random.uniform(-std, std, ( + 4 * hidden_size)).astype('float64') + self.bias_hh = np.random.uniform(-std, std, ( + 4 * hidden_size)).astype('float64') + self.parameters['bias_ih'] = self.bias_ih + self.parameters['bias_hh'] = 
self.bias_hh + else: + self.bias_ih = None + self.bias_hh = None + + def init_state(self, inputs): + batch_size = inputs.shape[0] + init_h = np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype) + init_c = np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype) + return init_h, init_c + + def forward(self, inputs, hx=None): + if hx is None: + hx = self.init_state(inputs) + pre_hidden, pre_cell = hx + gates = np.matmul(inputs, self.weight_ih.T) + if self.bias_ih is not None: + gates = gates + self.bias_ih + gates += np.matmul(pre_hidden, self.weight_hh.T) + if self.bias_hh is not None: + gates = gates + self.bias_hh + + chunked_gates = np.split(gates, 4, -1) + + i = 1.0 / (1.0 + np.exp(-chunked_gates[0])) + f = 1.0 / (1.0 + np.exp(-chunked_gates[1])) + o = 1.0 / (1.0 + np.exp(-chunked_gates[3])) + c = f * pre_cell + i * np.tanh(chunked_gates[2]) + h = o * np.tanh(c) + + return h, (h, c) + + +def sequence_mask(lengths, max_len=None): + if max_len is None: + max_len = np.max(lengths) + else: + assert max_len >= np.max(lengths) + return np.arange(max_len) < np.expand_dims(lengths, -1) + + +def update_state(mask, new, old): + if not isinstance(old, (tuple, list)): + return np.where(mask, new, old) + else: + return tuple(map(lambda x, y: np.where(mask, x, y), new, old)) + + +def rnn(cell, + inputs, + initial_states, + sequence_length=None, + time_major=False, + is_reverse=False): + if not time_major: + inputs = np.transpose(inputs, [1, 0, 2]) + if is_reverse: + inputs = np.flip(inputs, 0) + + if sequence_length is None: + mask = None + else: + mask = np.transpose(sequence_mask(sequence_length), [1, 0]) + mask = np.expand_dims(mask, -1) + if is_reverse: + mask = np.flip(mask, 0) + + time_steps = inputs.shape[0] + state = initial_states + outputs = [] + for t in range(time_steps): + x_t = inputs[t] + if mask is not None: + m_t = mask[t] + y, new_state = cell(x_t, state) + y = np.where(m_t, y, 0.) 
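+            # Beyond a sequence's valid length this reference emits zeros and
+            # carries the previous state forward (update_state below); the
+            # dygraph tests apply the same mask to paddle's outputs before
+            # comparing against these reference outputs.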
+ outputs.append(y) + state = update_state(m_t, new_state, state) + else: + y, new_state = cell(x_t, state) + outputs.append(y) + state = new_state + + outputs = np.stack(outputs) + final_state = state + + if is_reverse: + outputs = np.flip(outputs, 0) + if not time_major: + outputs = np.transpose(outputs, [1, 0, 2]) + return outputs, final_state + + +def birnn(cell_fw, + cell_bw, + inputs, + initial_states, + sequence_length=None, + time_major=False): + states_fw, states_bw = initial_states + outputs_fw, states_fw = rnn(cell_fw, + inputs, + states_fw, + sequence_length, + time_major=time_major) + + outputs_bw, states_bw = rnn(cell_bw, + inputs, + states_bw, + sequence_length, + time_major=time_major, + is_reverse=True) + + outputs = np.concatenate((outputs_fw, outputs_bw), -1) + final_states = (states_fw, states_bw) + return outputs, final_states + + +def flatten(nested): + return list(_flatten(nested)) + + +def _flatten(nested): + for item in nested: + if isinstance(item, (list, tuple)): + for subitem in _flatten(item): + yield subitem + else: + yield item + + +def unstack(array, axis=0): + num = array.shape[axis] + sub_arrays = np.split(array, num, axis) + return [np.squeeze(sub_array, axis) for sub_array in sub_arrays] + + +def dropout(array, p=0.5): + if p == 0.0: + return array + + mask = (np.random.uniform(size=array.shape) < (1 - p)).astype(array.dtype) + return array * (mask / (1 - p)) + + +def split_states(states, bidirectional=False, state_components=1): + if state_components == 1: + states = unstack(states) + if not bidirectional: + return states + else: + return list(zip(states[::2], states[1::2])) + else: + assert len(states) == state_components + states = tuple([unstack(item) for item in states]) + if not bidirectional: + return list(zip(*states)) + else: + states = list(zip(*states)) + return list(zip(states[::2], states[1::2])) + + +def concat_states(states, bidirectional=False, state_components=1): + if state_components == 1: + return np.stack(flatten(states)) + else: + states = flatten(states) + componnets = [] + for i in range(state_components): + componnets.append(states[i::state_components]) + return [np.stack(item) for item in componnets] + + +class RNN(LayerMixin): + def __init__(self, cell, is_reverse=False, time_major=False): + super(RNN, self).__init__() + self.cell = cell + if not hasattr(self.cell, "call"): + # for non-dygraph mode, `rnn` api uses cell.call + self.cell.call = self.cell.forward + self.is_reverse = is_reverse + self.time_major = time_major + + def forward(self, inputs, initial_states=None, sequence_length=None): + final_outputs, final_states = rnn(self.cell, + inputs, + initial_states=initial_states, + sequence_length=sequence_length, + time_major=self.time_major, + is_reverse=self.is_reverse) + return final_outputs, final_states + + +class BiRNN(LayerMixin): + def __init__(self, cell_fw, cell_bw, time_major=False): + super(BiRNN, self).__init__() + self.cell_fw = cell_fw + self.cell_bw = cell_bw + self.time_major = time_major + + def forward(self, + inputs, + initial_states=None, + sequence_length=None, + **kwargs): + if isinstance(initial_states, (list, tuple)): + assert len(initial_states) == 2, \ + "length of initial_states should be 2 when it is a list/tuple" + else: + initial_states = [initial_states, initial_states] + + outputs, final_states = birnn(self.cell_fw, self.cell_bw, inputs, + initial_states, sequence_length, + self.time_major) + return outputs, final_states + + +class RNNMixin(LayerListMixin): + def forward(self, inputs, 
initial_states=None, sequence_length=None): + batch_index = 1 if self.time_major else 0 + batch_size = inputs.shape[batch_index] + dtype = inputs.dtype + if initial_states is None: + state_shape = (self.num_layers * self.num_directions, batch_size, + self.hidden_size) + if self.state_components == 1: + initial_states = np.zeros(state_shape, dtype) + else: + initial_states = tuple([ + np.zeros(state_shape, dtype) + for _ in range(self.state_components) + ]) + + states = split_states(initial_states, self.num_directions == 2, + self.state_components) + final_states = [] + + for i, rnn_layer in enumerate(self): + if i > 0: + inputs = dropout(inputs, self.dropout) + outputs, final_state = rnn_layer(inputs, states[i], sequence_length) + final_states.append(final_state) + inputs = outputs + + final_states = concat_states(final_states, self.num_directions == 2, + self.state_components) + return outputs, final_states + + +class SimpleRNN(RNNMixin): + def __init__(self, + input_size, + hidden_size, + num_layers=1, + nonlinearity="tanh", + direction="forward", + dropout=0., + time_major=False): + super(SimpleRNN, self).__init__() + + if direction in ["forward", "backward"]: + is_reverse = direction == "backward" + cell = SimpleRNNCell(input_size, hidden_size, + nonlinearity=nonlinearity) + self.append(RNN(cell, is_reverse, time_major)) + for i in range(1, num_layers): + cell = SimpleRNNCell(hidden_size, hidden_size, + nonlinearity=nonlinearity) + self.append(RNN(cell, is_reverse, time_major)) + elif direction == "bidirectional": + cell_fw = SimpleRNNCell(input_size, hidden_size, + nonlinearity=nonlinearity) + cell_bw = SimpleRNNCell(input_size, hidden_size, + nonlinearity=nonlinearity) + self.append(BiRNN(cell_fw, cell_bw, time_major)) + for i in range(1, num_layers): + cell_fw = SimpleRNNCell(2 * hidden_size, hidden_size, + nonlinearity=nonlinearity) + cell_bw = SimpleRNNCell(2 * hidden_size, hidden_size, + nonlinearity=nonlinearity) + self.append(BiRNN(cell_fw, cell_bw, time_major)) + else: + raise ValueError( + "direction should be forward, backward or bidirectional, " + "received direction = {}".format(direction)) + + self.input_size = input_size + self.hidden_size = hidden_size + self.dropout = dropout + self.num_directions = 2 if direction == "bidirectional" else 1 + self.time_major = time_major + self.num_layers = num_layers + self.state_components = 1 + + +class LSTM(RNNMixin): + def __init__(self, + input_size, + hidden_size, + num_layers=1, + direction="forward", + dropout=0., + time_major=False): + super(LSTM, self).__init__() + + if direction in ["forward", "backward"]: + is_reverse = direction == "backward" + cell = LSTMCell(input_size, hidden_size) + self.append(RNN(cell, is_reverse, time_major)) + for i in range(1, num_layers): + cell = LSTMCell(hidden_size, hidden_size) + self.append(RNN(cell, is_reverse, time_major)) + elif direction == "bidirectional": + cell_fw = LSTMCell(input_size, hidden_size) + cell_bw = LSTMCell(input_size, hidden_size) + self.append(BiRNN(cell_fw, cell_bw, time_major)) + for i in range(1, num_layers): + cell_fw = LSTMCell(2 * hidden_size, hidden_size) + cell_bw = LSTMCell(2 * hidden_size, hidden_size) + self.append(BiRNN(cell_fw, cell_bw, time_major)) + else: + raise ValueError( + "direction should be forward, backward or bidirectional, " + "received direction = {}".format(direction)) + + self.input_size = input_size + self.hidden_size = hidden_size + self.dropout = dropout + self.num_directions = 2 if direction == "bidirectional" else 1 + self.time_major = time_major + self.num_layers = num_layers + self.state_components = 2 + + +class
GRU(RNNMixin): + def __init__(self, + input_size, + hidden_size, + num_layers=1, + direction="forward", + dropout=0., + time_major=False): + super(GRU, self).__init__() + + if direction in ["forward", "backward"]: + is_reverse = direction == "backward" + cell = GRUCell(input_size, hidden_size) + self.append(RNN(cell, is_reverse, time_major)) + for i in range(1, num_layers): + cell = GRUCell(hidden_size, hidden_size) + self.append(RNN(cell, is_reverse, time_major)) + elif direction == "bidirectional": + cell_fw = GRUCell(input_size, hidden_size) + cell_bw = GRUCell(input_size, hidden_size) + self.append(BiRNN(cell_fw, cell_bw, time_major)) + for i in range(1, num_layers): + cell_fw = GRUCell(2 * hidden_size, hidden_size) + cell_bw = GRUCell(2 * hidden_size, hidden_size) + self.append(BiRNN(cell_fw, cell_bw, time_major)) + else: + raise ValueError( + "direction should be forward, backward or bidirectional, " + "received direction = {}".format(direction)) + + self.input_size = input_size + self.hidden_size = hidden_size + self.dropout = dropout + self.num_directions = 2 if direction == "bidirectional" else 1 + self.time_major = time_major + self.num_layers = num_layers + self.state_components = 1 diff --git a/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells.py b/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells.py new file mode 100644 index 0000000000..8d2677229a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells.py @@ -0,0 +1,166 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
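The cell tests below all follow one pattern: build a numpy reference cell and a paddle.nn cell of the same size, copy the parameters across by name with convert_params_for_cell, then check that both produce the same hidden state on random input. A condensed sketch of that pattern, assuming the same 16-input/32-hidden sizes the tests use:

.. code-block:: python

    import numpy as np
    import paddle
    paddle.framework.set_default_dtype("float64")
    paddle.disable_static()

    from rnn_numpy import SimpleRNNCell
    from convert import convert_params_for_cell

    rnn1 = SimpleRNNCell(16, 32)            # numpy reference cell
    rnn2 = paddle.nn.SimpleRNNCell(16, 32)  # cell under test
    convert_params_for_cell(rnn1, rnn2)     # copy weights by parameter name

    x = np.random.randn(4, 16)
    y1, h1 = rnn1(x)                        # zero initial state
    y2, h2 = rnn2(paddle.to_variable(x))
    np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)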
+ +import paddle +paddle.framework.set_default_dtype("float64") + +import numpy as np +import unittest + +from rnn_numpy import SimpleRNNCell, LSTMCell, GRUCell +from convert import convert_params_for_cell + + +class TestSimpleRNNCell(unittest.TestCase): + def __init__(self, bias=True, place="cpu"): + super(TestSimpleRNNCell, self).__init__(methodName="runTest") + self.bias = bias + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + paddle.disable_static(self.place) + rnn1 = SimpleRNNCell(16, 32, bias=self.bias) + rnn2 = paddle.nn.SimpleRNNCell( + 16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias) + convert_params_for_cell(rnn1, rnn2) + + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + def test_with_initial_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(4, 16) + prev_h = np.random.randn(4, 32) + + y1, h1 = rnn1(x, prev_h) + y2, h2 = rnn2(paddle.to_variable(x), paddle.to_variable(prev_h)) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(4, 16) + + y1, h1 = rnn1(x) + y2, h2 = rnn2(paddle.to_variable(x)) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + + +class TestGRUCell(unittest.TestCase): + def __init__(self, bias=True, place="cpu"): + super(TestGRUCell, self).__init__(methodName="runTest") + self.bias = bias + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + paddle.disable_static(self.place) + rnn1 = GRUCell(16, 32, bias=self.bias) + rnn2 = paddle.nn.GRUCell( + 16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias) + convert_params_for_cell(rnn1, rnn2) + + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + def test_with_initial_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(4, 16) + prev_h = np.random.randn(4, 32) + + y1, h1 = rnn1(x, prev_h) + y2, h2 = rnn2(paddle.to_variable(x), paddle.to_variable(prev_h)) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(4, 16) + + y1, h1 = rnn1(x) + y2, h2 = rnn2(paddle.to_variable(x)) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + + +class TestLSTMCell(unittest.TestCase): + def __init__(self, bias=True, place="cpu"): + super(TestLSTMCell, self).__init__(methodName="runTest") + self.bias = bias + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + rnn1 = LSTMCell(16, 32, bias=self.bias) + rnn2 = paddle.nn.LSTMCell( + 16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias) + convert_params_for_cell(rnn1, rnn2) + + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + def test_with_initial_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(4, 16) + prev_h = np.random.randn(4, 32) + prev_c = np.random.randn(4, 32) + + y1, (h1, c1) = rnn1(x, (prev_h, prev_c)) + y2, (h2, c2) = rnn2( + paddle.to_variable(x), + (paddle.to_variable(prev_h), paddle.to_variable(prev_c))) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(4, 16) + + y1, (h1, c1) = 
rnn1(x) + y2, (h2, c2) = rnn2(paddle.to_variable(x)) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + + +def load_tests(loader, tests, pattern): + suite = unittest.TestSuite() + devices = ["cpu", "gpu"] if paddle.fluid.is_compiled_with_cuda() \ + else ["cpu"] + for bias in [True, False]: + for device in devices: + for test_class in [TestSimpleRNNCell, TestGRUCell, TestLSTMCell]: + suite.addTest(test_class(bias, device)) + return suite diff --git a/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells_static.py b/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells_static.py new file mode 100644 index 0000000000..948e47d5b9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells_static.py @@ -0,0 +1,326 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +paddle.framework.set_default_dtype("float64") + +import numpy as np +import unittest + +from convert import convert_params_for_cell_static +from rnn_numpy import SimpleRNNCell, LSTMCell, GRUCell + + +class TestSimpleRNNCell(unittest.TestCase): + def __init__(self, bias=True, place="cpu"): + super(TestSimpleRNNCell, self).__init__(methodName="runTest") + self.bias = bias + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + rnn1 = SimpleRNNCell(16, 32, bias=self.bias) + + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + rnn2 = paddle.nn.SimpleRNNCell( + 16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias) + + place = self.place + exe = paddle.static.Executor(place) + scope = paddle.fluid.Scope() + with paddle.static.scope_guard(scope): + exe.run(sp) + convert_params_for_cell_static(rnn1, rnn2, place) + + self.mp = mp + self.sp = sp + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + self.executor = exe + self.scope = scope + + def test_with_initial_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(4, 16) + prev_h = np.random.randn(4, 32) + + y1, h1 = rnn1(x, prev_h) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, 16], + dtype=paddle.framework.get_default_dtype()) + init_h = paddle.data( + "init_h", [-1, 32], + dtype=paddle.framework.get_default_dtype()) + y, h = rnn2(x_data, init_h) + + feed_dict = {x_data.name: x, init_h.name: prev_h} + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h]) + + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = 
np.random.randn(4, 16) + + y1, h1 = rnn1(x) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, 16], + dtype=paddle.framework.get_default_dtype()) + y, h = rnn2(x_data) + + feed_dict = {x_data.name: x} + + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, + feed=feed_dict, + fetch_list=[y, h], + use_prune=True) + + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + + +class TestGRUCell(unittest.TestCase): + def __init__(self, bias=True, place="cpu"): + super(TestGRUCell, self).__init__(methodName="runTest") + self.bias = bias + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + rnn1 = GRUCell(16, 32, bias=self.bias) + + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + rnn2 = paddle.nn.GRUCell( + 16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias) + + place = self.place + exe = paddle.static.Executor(place) + scope = paddle.fluid.Scope() + with paddle.static.scope_guard(scope): + exe.run(sp) + convert_params_for_cell_static(rnn1, rnn2, place) + + self.mp = mp + self.sp = sp + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + self.place = place + self.executor = exe + self.scope = scope + + def test_with_initial_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(4, 16) + prev_h = np.random.randn(4, 32) + + y1, h1 = rnn1(x, prev_h) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, 16], + dtype=paddle.framework.get_default_dtype()) + init_h = paddle.data( + "init_h", [-1, 32], + dtype=paddle.framework.get_default_dtype()) + y, h = rnn2(x_data, init_h) + + feed_dict = {x_data.name: x, init_h.name: prev_h} + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h]) + + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(4, 16) + + y1, h1 = rnn1(x) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, 16], + dtype=paddle.framework.get_default_dtype()) + y, h = rnn2(x_data) + + feed_dict = {x_data.name: x} + + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, + feed=feed_dict, + fetch_list=[y, h], + use_prune=True) + + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + + +class TestLSTMCell(unittest.TestCase): + def __init__(self, bias=True, place="cpu"): + super(TestLSTMCell, self).__init__(methodName="runTest") + self.bias = bias + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + rnn1 = LSTMCell(16, 32, bias=self.bias) + + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + rnn2 = paddle.nn.LSTMCell( + 16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias) + + place = self.place + exe = paddle.static.Executor(place) + scope = paddle.fluid.Scope() + with 
paddle.static.scope_guard(scope): + exe.run(sp) + convert_params_for_cell_static(rnn1, rnn2, place) + + self.mp = mp + self.sp = sp + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + self.place = place + self.executor = exe + self.scope = scope + + def test_with_initial_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(4, 16) + prev_h = np.random.randn(4, 32) + prev_c = np.random.randn(4, 32) + + y1, (h1, c1) = rnn1(x, (prev_h, prev_c)) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, 16], + dtype=paddle.framework.get_default_dtype()) + init_h = paddle.data( + "init_h", [-1, 32], + dtype=paddle.framework.get_default_dtype()) + init_c = paddle.data( + "init_c", [-1, 32], + dtype=paddle.framework.get_default_dtype()) + y, (h, c) = rnn2(x_data, (init_h, init_c)) + + feed_dict = {x_data.name: x, init_h.name: prev_h, init_c.name: prev_c} + with paddle.static.scope_guard(scope): + y2, h2, c2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h, c]) + + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2, atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(4, 16) + + y1, (h1, c1) = rnn1(x) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, 16], + dtype=paddle.framework.get_default_dtype()) + y, (h, c) = rnn2(x_data) + + feed_dict = {x_data.name: x} + + with paddle.static.scope_guard(scope): + y2, h2, c2 = exe.run(mp, + feed=feed_dict, + fetch_list=[y, h, c], + use_prune=True) + + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2, atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + + +def load_tests(loader, tests, pattern): + suite = unittest.TestSuite() + devices = ["cpu", "gpu"] if paddle.fluid.is_compiled_with_cuda() \ + else ["cpu"] + for bias in [True, False]: + for device in devices: + for test_class in [TestSimpleRNNCell, TestGRUCell, TestLSTMCell]: + suite.addTest(test_class(bias, device)) + return suite diff --git a/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets.py b/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets.py new file mode 100644 index 0000000000..ef297b3bb6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets.py @@ -0,0 +1,269 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
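The network-level tests repeat the same reference-vs-paddle comparison for multi-layer, possibly bidirectional nets, with one extra wrinkle: for variable-length input the paddle outputs are masked before comparison, because the numpy reference zeroes padded time steps while paddle's outputs at those steps are unspecified. A condensed sketch, assuming the 16-input/32-hidden/2-layer sizes the tests use:

.. code-block:: python

    import numpy as np
    import paddle
    paddle.set_default_dtype("float64")
    paddle.disable_static()

    from rnn_numpy import LSTM
    from convert import convert_params_for_net

    rnn1 = LSTM(16, 32, 2, time_major=True)            # numpy reference
    rnn2 = paddle.nn.LSTM(16, 32, 2, time_major=True)  # net under test
    convert_params_for_net(rnn1, rnn2)                 # copy weights layer by layer

    x = np.random.randn(12, 4, 16)                     # [time, batch, input]
    y1, (h1, c1) = rnn1(x)
    y2, (h2, c2) = rnn2(paddle.to_variable(x))
    np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)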
+ +import paddle +paddle.set_default_dtype("float64") +from paddle.fluid.layers import sequence_mask + +import numpy as np +import unittest + +from convert import convert_params_for_net +from rnn_numpy import SimpleRNN, LSTM, GRU + + +class TestSimpleRNN(unittest.TestCase): + def __init__(self, time_major=True, direction="forward", place="cpu"): + super(TestSimpleRNN, self).__init__("runTest") + self.time_major = time_major + self.direction = direction + self.num_directions = 2 if direction == "bidirectional" else 1 + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + paddle.disable_static(self.place) + rnn1 = SimpleRNN( + 16, 32, 2, time_major=self.time_major, direction=self.direction) + rnn2 = paddle.nn.SimpleRNN( + 16, 32, 2, time_major=self.time_major, direction=self.direction) + convert_params_for_net(rnn1, rnn2) + + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + def test_with_initial_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + prev_h = np.random.randn(2 * self.num_directions, 4, 32) + + y1, h1 = rnn1(x, prev_h) + y2, h2 = rnn2(paddle.to_variable(x), paddle.to_variable(prev_h)) + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + + y1, h1 = rnn1(x) + y2, h2 = rnn2(paddle.to_variable(x)) + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def test_with_input_lengths(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + sequence_length = np.array([12, 10, 9, 8], dtype=np.int64) + + y1, h1 = rnn1(x, sequence_length=sequence_length) + + seq_len = paddle.to_variable(sequence_length) + mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype()) + if self.time_major: + mask = paddle.transpose(mask, [1, 0]) + y2, h2 = rnn2(paddle.to_variable(x), sequence_length=seq_len) + y2 = paddle.multiply(y2, mask, axis=0) + + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + self.test_with_input_lengths() + + +class TestGRU(unittest.TestCase): + def __init__(self, time_major=True, direction="forward", place="cpu"): + super(TestGRU, self).__init__("runTest") + self.time_major = time_major + self.direction = direction + self.num_directions = 2 if direction == "bidirectional" else 1 + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + paddle.disable_static(self.place) + rnn1 = GRU(16, + 32, + 2, + time_major=self.time_major, + direction=self.direction) + rnn2 = paddle.nn.GRU(16, + 32, + 2, + time_major=self.time_major, + direction=self.direction) + convert_params_for_net(rnn1, rnn2) + + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + def test_with_initial_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + prev_h = np.random.randn(2 * self.num_directions, 4, 32) + + y1, h1 = rnn1(x, prev_h) + y2, h2 = rnn2(paddle.to_variable(x), 
paddle.to_variable(prev_h)) + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + + y1, h1 = rnn1(x) + y2, h2 = rnn2(paddle.to_variable(x)) + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def test_with_input_lengths(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + sequence_length = np.array([12, 10, 9, 8], dtype=np.int64) + + y1, h1 = rnn1(x, sequence_length=sequence_length) + + seq_len = paddle.to_variable(sequence_length) + mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype()) + if self.time_major: + mask = paddle.transpose(mask, [1, 0]) + y2, h2 = rnn2(paddle.to_variable(x), sequence_length=seq_len) + y2 = paddle.multiply(y2, mask, axis=0) + + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + self.test_with_input_lengths() + + +class TestLSTM(unittest.TestCase): + def __init__(self, time_major=True, direction="forward", place="cpu"): + super(TestLSTM, self).__init__("runTest") + self.time_major = time_major + self.direction = direction + self.num_directions = 2 if direction == "bidirectional" else 1 + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + paddle.disable_static(self.place) + rnn1 = LSTM( + 16, 32, 2, time_major=self.time_major, direction=self.direction) + rnn2 = paddle.nn.LSTM( + 16, 32, 2, time_major=self.time_major, direction=self.direction) + convert_params_for_net(rnn1, rnn2) + + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + def test_with_initial_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + prev_h = np.random.randn(2 * self.num_directions, 4, 32) + prev_c = np.random.randn(2 * self.num_directions, 4, 32) + + y1, (h1, c1) = rnn1(x, (prev_h, prev_c)) + y2, (h2, c2) = rnn2( + paddle.to_variable(x), + (paddle.to_variable(prev_h), paddle.to_variable(prev_c))) + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + + y1, (h1, c1) = rnn1(x) + y2, (h2, c2) = rnn2(paddle.to_variable(x)) + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5) + + def test_with_input_lengths(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + sequence_length = np.array([12, 10, 9, 8], dtype=np.int64) + + y1, (h1, c1) = rnn1(x, sequence_length=sequence_length) + + seq_len = paddle.to_variable(sequence_length) + mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype()) + if self.time_major: + mask = 
paddle.transpose(mask, [1, 0]) + y2, (h2, c2) = rnn2(paddle.to_variable(x), sequence_length=seq_len) + y2 = paddle.multiply(y2, mask, axis=0) + + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + self.test_with_input_lengths() + + +def load_tests(loader, tests, pattern): + suite = unittest.TestSuite() + devices = ["cpu", "gpu"] if paddle.fluid.is_compiled_with_cuda() \ + else ["cpu"] + for direction in ["forward", "backward", "bidirectional"]: + for time_major in [True, False]: + for device in devices: + for test_class in [TestSimpleRNN, TestLSTM, TestGRU]: + suite.addTest(test_class(time_major, direction, device)) + return suite diff --git a/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets_static.py b/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets_static.py new file mode 100644 index 0000000000..90ed6b8b4c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets_static.py @@ -0,0 +1,470 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
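The static-graph versions rebuild the same checks with Program/Executor plumbing: the paddle net and its paddle.data placeholders are declared under a program_guard, the startup program is run, the reference weights are copied into the scope with convert_params_for_net_static, and the fetched results are compared against the numpy reference. A condensed sketch of that flow, collapsing the separate setUp and per-test programs the tests use into one:

.. code-block:: python

    import numpy as np
    import paddle
    paddle.set_default_dtype("float64")

    from rnn_numpy import GRU
    from convert import convert_params_for_net_static

    rnn1 = GRU(16, 32, 2, time_major=True)  # numpy reference
    mp, sp = paddle.static.Program(), paddle.static.Program()
    with paddle.fluid.unique_name.guard():
        with paddle.static.program_guard(mp, sp):
            rnn2 = paddle.nn.GRU(16, 32, 2, time_major=True)
            x_data = paddle.data("input", [-1, -1, 16], dtype="float64")
            y, h = rnn2(x_data)

    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    scope = paddle.fluid.Scope()
    with paddle.static.scope_guard(scope):
        exe.run(sp)                                       # init parameters
        convert_params_for_net_static(rnn1, rnn2, place)  # overwrite with reference weights
        x = np.random.randn(12, 4, 16)
        y1, h1 = rnn1(x)
        y2, h2 = exe.run(mp, feed={x_data.name: x}, fetch_list=[y, h])
    np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5)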
+ +import paddle +paddle.set_default_dtype("float64") +from paddle.fluid.layers import sequence_mask + +import numpy as np +import unittest + +from convert import convert_params_for_net_static +from rnn_numpy import SimpleRNN, LSTM, GRU + + +class TestSimpleRNN(unittest.TestCase): + def __init__(self, time_major=True, direction="forward", place="cpu"): + super(TestSimpleRNN, self).__init__("runTest") + self.time_major = time_major + self.direction = direction + self.num_directions = 2 if direction == "bidirectional" else 1 + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + rnn1 = SimpleRNN( + 16, 32, 2, time_major=self.time_major, direction=self.direction) + + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + rnn2 = paddle.nn.SimpleRNN( + 16, + 32, + 2, + time_major=self.time_major, + direction=self.direction) + + place = self.place + exe = paddle.static.Executor(place) + scope = paddle.fluid.Scope() + with paddle.static.scope_guard(scope): + exe.run(sp) + convert_params_for_net_static(rnn1, rnn2, place) + + self.mp = mp + self.sp = sp + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + self.place = place + self.executor = exe + self.scope = scope + + def test_with_initial_state(self): + mp = self.mp.clone().clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + prev_h = np.random.randn(2 * self.num_directions, 4, 32) + + y1, h1 = rnn1(x, prev_h) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, -1, 16], + dtype=paddle.framework.get_default_dtype()) + init_h = paddle.data( + "init_h", [2 * self.num_directions, -1, 32], + dtype=paddle.framework.get_default_dtype()) + y, h = rnn2(x_data, init_h) + + feed_dict = {x_data.name: x, init_h.name: prev_h} + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h]) + + np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + + y1, h1 = rnn1(x) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, -1, 16], + dtype=paddle.framework.get_default_dtype()) + y, h = rnn2(x_data) + + feed_dict = {x_data.name: x} + + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h]) + + np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def test_with_input_lengths(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + sequence_length = np.array([12, 10, 9, 8], dtype=np.int64) + + y1, h1 = rnn1(x, sequence_length=sequence_length) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, -1, 16], + dtype=paddle.framework.get_default_dtype()) + seq_len = 
paddle.data("seq_len", [-1], dtype="int64") + mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype()) + if self.time_major: + mask = paddle.transpose(mask, [1, 0]) + y, h = rnn2(x_data, sequence_length=seq_len) + y = paddle.multiply(y, mask, axis=0) + + feed_dict = {x_data.name: x, seq_len.name: sequence_length} + + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h]) + + np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + self.test_with_input_lengths() + + +class TestGRU(unittest.TestCase): + def __init__(self, time_major=True, direction="forward", place="cpu"): + super(TestGRU, self).__init__("runTest") + self.time_major = time_major + self.direction = direction + self.num_directions = 2 if direction == "bidirectional" else 1 + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + rnn1 = GRU(16, + 32, + 2, + time_major=self.time_major, + direction=self.direction) + + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + rnn2 = paddle.nn.GRU(16, + 32, + 2, + time_major=self.time_major, + direction=self.direction) + + place = self.place + exe = paddle.static.Executor(place) + scope = paddle.fluid.Scope() + with paddle.static.scope_guard(scope): + exe.run(sp) + convert_params_for_net_static(rnn1, rnn2, place) + + self.mp = mp + self.sp = sp + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + self.place = place + self.executor = exe + self.scope = scope + + def test_with_initial_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + + prev_h = np.random.randn(2 * self.num_directions, 4, 32) + + y1, h1 = rnn1(x, prev_h) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, -1, 16], + dtype=paddle.framework.get_default_dtype()) + init_h = paddle.data( + "init_h", [2 * self.num_directions, -1, 32], + dtype=paddle.framework.get_default_dtype()) + y, h = rnn2(x_data, init_h) + + feed_dict = {x_data.name: x, init_h.name: prev_h} + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h]) + + np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + + y1, h1 = rnn1(x) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, -1, 16], + dtype=paddle.framework.get_default_dtype()) + y, h = rnn2(x_data) + + feed_dict = {x_data.name: x} + + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h]) + + np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def test_with_input_lengths(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = 
np.random.randn(12, 4, 16)
+        if not self.time_major:
+            x = np.transpose(x, [1, 0, 2])
+        sequence_length = np.array([12, 10, 9, 8], dtype=np.int64)
+
+        y1, h1 = rnn1(x, sequence_length=sequence_length)
+
+        with paddle.fluid.unique_name.guard():
+            with paddle.static.program_guard(mp, sp):
+                x_data = paddle.data(
+                    "input", [-1, -1, 16],
+                    dtype=paddle.framework.get_default_dtype())
+                seq_len = paddle.data("seq_len", [-1], dtype="int64")
+                mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype())
+                if self.time_major:
+                    mask = paddle.transpose(mask, [1, 0])
+                y, h = rnn2(x_data, sequence_length=seq_len)
+                y = paddle.multiply(y, mask, axis=0)
+
+        feed_dict = {x_data.name: x, seq_len.name: sequence_length}
+
+        with paddle.static.scope_guard(scope):
+            y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h])
+
+        np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5)
+        np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5)
+
+    def runTest(self):
+        self.test_with_initial_state()
+        self.test_with_zero_state()
+        self.test_with_input_lengths()
+
+
+class TestLSTM(unittest.TestCase):
+    def __init__(self, time_major=True, direction="forward", place="cpu"):
+        super(TestLSTM, self).__init__("runTest")
+        self.time_major = time_major
+        self.direction = direction
+        self.num_directions = 2 if direction == "bidirectional" else 1
+        self.place = paddle.CPUPlace() if place == "cpu" \
+            else paddle.CUDAPlace(0)
+
+    def setUp(self):
+        rnn1 = LSTM(
+            16, 32, 2, time_major=self.time_major, direction=self.direction)
+
+        mp = paddle.static.Program()
+        sp = paddle.static.Program()
+        with paddle.fluid.unique_name.guard():
+            with paddle.static.program_guard(mp, sp):
+                rnn2 = paddle.nn.LSTM(
+                    16,
+                    32,
+                    2,
+                    time_major=self.time_major,
+                    direction=self.direction)
+
+        place = self.place
+        exe = paddle.static.Executor(place)
+        scope = paddle.fluid.Scope()
+        with paddle.static.scope_guard(scope):
+            exe.run(sp)
+            convert_params_for_net_static(rnn1, rnn2, place)
+
+        self.mp = mp
+        self.sp = sp
+        self.rnn1 = rnn1
+        self.rnn2 = rnn2
+
+        self.place = place
+        self.executor = exe
+        self.scope = scope
+
+    def test_with_initial_state(self):
+        mp = self.mp.clone()
+        sp = self.sp
+        rnn1 = self.rnn1
+        rnn2 = self.rnn2
+        exe = self.executor
+        scope = self.scope
+
+        x = np.random.randn(12, 4, 16)
+        if not self.time_major:
+            x = np.transpose(x, [1, 0, 2])
+        prev_h = np.random.randn(2 * self.num_directions, 4, 32)
+        prev_c = np.random.randn(2 * self.num_directions, 4, 32)
+
+        y1, (h1, c1) = rnn1(x, (prev_h, prev_c))
+
+        with paddle.fluid.unique_name.guard():
+            with paddle.static.program_guard(mp, sp):
+                x_data = paddle.data(
+                    "input", [-1, -1, 16],
+                    dtype=paddle.framework.get_default_dtype())
+                init_h = paddle.data(
+                    "init_h", [2 * self.num_directions, -1, 32],
+                    dtype=paddle.framework.get_default_dtype())
+                init_c = paddle.data(
+                    "init_c", [2 * self.num_directions, -1, 32],
+                    dtype=paddle.framework.get_default_dtype())
+                y, (h, c) = rnn2(x_data, (init_h, init_c))
+
+        feed_dict = {x_data.name: x, init_h.name: prev_h, init_c.name: prev_c}
+        with paddle.static.scope_guard(scope):
+            y2, h2, c2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h, c])
+
+        np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5)
+        np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5)
+        np.testing.assert_allclose(c1, c2, atol=1e-8, rtol=1e-5)
+
+    def test_with_zero_state(self):
+        mp = self.mp.clone()
+        sp = self.sp
+        rnn1 = self.rnn1
+        rnn2 = self.rnn2
+        exe = self.executor
+        scope = self.scope
+
+        x = np.random.randn(12, 4, 16)
+        if not self.time_major:
+            x = np.transpose(x, [1, 0,
2]) + + y1, (h1, c1) = rnn1(x) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, -1, 16], + dtype=paddle.framework.get_default_dtype()) + y, (h, c) = rnn2(x_data) + + feed_dict = {x_data.name: x} + + with paddle.static.scope_guard(scope): + y2, h2, c2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h, c]) + + np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2, atol=1e-8, rtol=1e-5) + + def test_with_input_lengths(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + sequence_length = np.array([12, 10, 9, 8], dtype=np.int64) + + y1, (h1, c1) = rnn1(x, sequence_length=sequence_length) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, -1, 16], + dtype=paddle.framework.get_default_dtype()) + seq_len = paddle.data("seq_len", [-1], dtype="int64") + mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype()) + if self.time_major: + mask = paddle.transpose(mask, [1, 0]) + y, (h, c) = rnn2(x_data, sequence_length=seq_len) + y = paddle.multiply(y, mask, axis=0) + + feed_dict = {x_data.name: x, seq_len.name: sequence_length} + + with paddle.static.scope_guard(scope): + y2, h2, c2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h, c]) + + np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2, atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + self.test_with_input_lengths() + + +def load_tests(loader, tests, pattern): + suite = unittest.TestSuite() + devices = ["cpu", "gpu"] if paddle.fluid.is_compiled_with_cuda() \ + else ["cpu"] + for direction in ["forward", "backward", "bidirectional"]: + for time_major in [True, False]: + for device in devices: + for test_class in [TestSimpleRNN, TestLSTM, TestGRU]: + suite.addTest(test_class(time_major, direction, device)) + return suite diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 62d389209b..645b211565 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -18,6 +18,7 @@ from .layer import norm from .functional import extension from .layer import common +from .layer import rnn from .utils import weight_norm_hook from . 
import initializer @@ -26,6 +27,7 @@ __all__ = [] __all__ += norm.__all__ __all__ += extension.__all__ __all__ += common.__all__ +__all__ += rnn.__all__ __all__ += weight_norm_hook.__all__ # TODO: define alias in nn directory @@ -136,6 +138,7 @@ from .layer.norm import InstanceNorm3d #DEFINE_ALIAS from .layer.norm import BatchNorm1d #DEFINE_ALIAS from .layer.norm import BatchNorm2d #DEFINE_ALIAS from .layer.norm import BatchNorm3d #DEFINE_ALIAS +from .layer.rnn import * # from .layer.rnn import RNNCell #DEFINE_ALIAS # from .layer.rnn import GRUCell #DEFINE_ALIAS # from .layer.rnn import LSTMCell #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 97a4d5432b..75e2da4cf7 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -177,6 +177,8 @@ from .pooling import pool2d #DEFINE_ALIAS from .pooling import pool3d #DEFINE_ALIAS from .pooling import adaptive_pool2d #DEFINE_ALIAS from .pooling import adaptive_pool3d #DEFINE_ALIAS +from .rnn import rnn #DEFINE_ALIAS +from .rnn import birnn #DEFINE_ALIAS from .pooling import avg_pool2d #DEFINE_ALIAS from .pooling import max_pool2d #DEFINE_ALIAS from .pooling import avg_pool3d #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/rnn.py b/python/paddle/nn/functional/rnn.py index 520cf44360..b7a97bc5aa 100644 --- a/python/paddle/nn/functional/rnn.py +++ b/python/paddle/nn/functional/rnn.py @@ -12,10 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# TODO: define function of recurrent neural network +from paddle.fluid.layers.rnn import rnn, birnn -__all__ = [ - # 'gru_unit', - # 'lstm', - # 'lstm_unit' -] +__all__ = ['rnn', 'birnn'] diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 2eb9358f7f..3399e4e34c 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -20,6 +20,7 @@ from . import conv from . import extension from . import activation from . import norm +from . import rnn from . import vision from . import distance from . import transformer @@ -30,6 +31,7 @@ from .conv import * from .extension import * from .activation import * from .norm import * +from .rnn import * from .vision import * from .transformer import * diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index 4717609503..6f1c5f199a 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -12,10 +12,1333 @@ # See the License for the specific language governing permissions and # limitations under the License. 
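+# Editorial note (illustrative, not part of the original patch): this module
+# is organized in three tiers, all defined below:
+#   1. cells (SimpleRNNCell, LSTMCell, GRUCell) compute a single time step;
+#   2. wrappers (RNN, BiRNN) run a cell, or a pair of cells, over a sequence;
+#   3. networks (SimpleRNN, LSTM, GRU) stack wrapped layers, sharing a common
+#      forward method through RNNMixin.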
-# TODO: define classes of recurrent neural network
+import copy
+import collections
+import itertools
+import six
+import math
+import sys
+import warnings
+from functools import partial, reduce
+
+import paddle
+from paddle import framework
+from paddle.nn import functional as F
+from paddle.nn import initializer as I
+from paddle.fluid.dygraph import Layer, LayerList
+from paddle.fluid.layers import utils
+from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as
+from paddle.fluid.data_feeder import convert_dtype
 
 __all__ = [
-    # 'RNNCell',
-    # 'GRUCell',
-    # 'LSTMCell'
+    'RNNCellBase',
+    'SimpleRNNCell',
+    'LSTMCell',
+    'GRUCell',
+    'RNN',
+    'BiRNN',
+    'SimpleRNN',
+    'LSTM',
+    'GRU',
 ]
+
+
+def split_states(states, bidirectional=False, state_components=1):
+    r"""
+    Split the states of an RNN network into a possibly nested list or tuple of
+    states, one for each RNN cell of the network.
+
+    Arguments:
+        states (Tensor|tuple|list): the concatenated states for RNN network.
+            When `state_components` is 1, `states` is a Tensor with shape
+            `(L*D, N, C)` where `L` is the number of layers of the RNN
+            network, `D` is the number of directions of the RNN network (1
+            for unidirectional RNNs and 2 for bidirectional RNNs), `N` is
+            the batch size of the input to the RNN network, and `C` is the
+            hidden size of the RNN network.
+
+            When `state_components` is larger than 1, `states` is a tuple of
+            `state_components` Tensors that meet the requirements described
+            above.
+
+            For SimpleRNNs and GRUs, `state_components` is 1, and for LSTMs,
+            `state_components` is 2.
+        bidirectional (bool): whether the state is of a bidirectional RNN
+            network. Defaults to False.
+        state_components (int): the number of the components of the states. See
+            `states` above. Defaults to 1.
+
+    Returns:
+        A nested list or tuple of RNN cell states.
+        If `bidirectional` is True, it can be indexed twice to get an RNN
+        cell state. The first index indicates the layer, the second index
+        indicates the direction.
+        If `bidirectional` is False, it can be indexed once to get an RNN
+        cell state. The index indicates the layer.
+        Note that if `state_components` is larger than 1, an RNN cell state
+        can be indexed one more time to get a tensor of shape `(N, C)`, where
+        `N` is the batch size of the input to the RNN cell, and `C` is the
+        hidden size of the RNN cell.
+    """
+    if state_components == 1:
+        states = paddle.unstack(states)
+        if not bidirectional:
+            return states
+        else:
+            return list(zip(states[::2], states[1::2]))
+    else:
+        assert len(states) == state_components
+        states = tuple([paddle.unstack(item) for item in states])
+        if not bidirectional:
+            return list(zip(*states))
+        else:
+            states = list(zip(*states))
+            return list(zip(states[::2], states[1::2]))
+
+
+def concat_states(states, bidirectional=False, state_components=1):
+    r"""
+    Concatenate a possibly nested list or tuple of RNN cell states into a
+    compact form.
+
+    Arguments:
+        states (list|tuple): a possibly nested list or tuple of RNN cell
+            states.
+            If `bidirectional` is True, it can be indexed twice to get an
+            RNN cell state. The first index indicates the layer, the second
+            index indicates the direction.
+            If `bidirectional` is False, it can be indexed once to get an RNN
+            cell state. The index indicates the layer.
+            Note that if `state_components` is larger than 1, an RNN cell
+            state can be indexed one more time to get a tensor of shape
+            `(N, C)`, where `N` is the batch size of the input to the RNN
+            cell, and `C` is the hidden size of the RNN cell.
+        bidirectional (bool): whether the state is of a bidirectional RNN
+            network. Defaults to False.
+        state_components (int): the number of the components of the states. See
+            `states` above. Defaults to 1.
+
+    Returns:
+        Concatenated states for RNN network.
+        When `state_components` is 1, the result is a Tensor with shape
+        `(L\*D, N, C)` where `L` is the number of layers of the RNN
+        network, `D` is the number of directions of the RNN network (1 for
+        unidirectional RNNs and 2 for bidirectional RNNs), `N` is the batch
+        size of the input to the RNN network, and `C` is the hidden size of
+        the RNN network.
+    """
+    if state_components == 1:
+        return paddle.stack(flatten(states))
+    else:
+        states = flatten(states)
+        components = []
+        for i in range(state_components):
+            components.append(states[i::state_components])
+        return [paddle.stack(item) for item in components]
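+# Editorial note (illustrative, not part of the original patch): split_states
+# and concat_states are inverses. A round trip for a 2-layer bidirectional
+# LSTM (state_components=2); the shapes are assumptions for this sketch only:
+#
+#     h = paddle.zeros([4, 8, 32])   # (L*D, N, C): 2 layers, 2 directions
+#     c = paddle.zeros([4, 8, 32])
+#     per_cell = split_states((h, c), bidirectional=True, state_components=2)
+#     # per_cell[layer][direction] is an (h_i, c_i) pair of (8, 32) tensors
+#     h2, c2 = concat_states(per_cell, bidirectional=True, state_components=2)
+#     # h2 and c2 reproduce h and c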
+
+
+class RNNCellBase(Layer):
+    r"""
+    RNNCellBase is the base class for abstractions that represent the
+    calculation mapping an input and a state to an output and a new state.
+    It is suitable for, and mostly used in, RNNs.
+    """
+
+    def get_initial_states(self,
+                           batch_ref,
+                           shape=None,
+                           dtype=None,
+                           init_value=0.,
+                           batch_dim_idx=0):
+        r"""
+        Generate initialized states according to the provided shape, data type
+        and value.
+        Arguments:
+            batch_ref (Tensor): A tensor, whose shape would be used to
+                determine the batch size, which is used to generate initial
+                states. For `batch_ref`'s shape d, `d[batch_dim_idx]` is
+                treated as batch size.
+            shape (list|tuple, optional): A (possibly nested structure of) shape[s],
+                where a shape is a list/tuple of integers. `-1` (for batch size)
+                will be automatically prepended if a shape does not start with
+                it. If None, property `state_shape` will be used. Defaults to
+                None.
+            dtype (str|list|tuple, optional): A (possibly nested structure of)
+                data type[s]. The structure must be same as that of `shape`,
+                except when all tensors in states have the same data type, a
+                single data type can be used. If None and property `cell.state_shape`
+                is not available, the current default floating point type of
+                paddle is used. Defaults to None.
+            init_value (float, optional): A float value used to initialize states.
+                Defaults to 0.
+            batch_dim_idx (int, optional): An integer indicating which
+                dimension of `batch_ref` represents batch. Defaults to 0.
+        Returns:
+            init_states (Tensor|tuple|list): tensor of the provided shape and
+                dtype, or list of tensors that each satisfies the requirements,
+                packed in the same structure as `shape` and `dtype` do.
+ """ + # TODO: use inputs and batch_size + batch_ref = flatten(batch_ref)[0] + + def _is_shape_sequence(seq): + if sys.version_info < (3, ): + integer_types = ( + int, + long, ) + else: + integer_types = (int, ) + """For shape, list/tuple of integer is the finest-grained objection""" + if (isinstance(seq, list) or isinstance(seq, tuple)): + if reduce(lambda flag, x: isinstance(x, integer_types) and flag, + seq, True): + return False + # TODO: Add check for the illegal + if isinstance(seq, dict): + return True + return (isinstance(seq, collections.Sequence) and + not isinstance(seq, six.string_types)) + + class Shape(object): + def __init__(self, shape): + self.shape = shape if shape[0] == -1 else ([-1] + list(shape)) + + # nested structure of shapes + states_shapes = self.state_shape if shape is None else shape + is_sequence_ori = utils.is_sequence + utils.is_sequence = _is_shape_sequence + states_shapes = map_structure(lambda shape: Shape(shape), states_shapes) + utils.is_sequence = is_sequence_ori + + # nested structure of dtypes + try: + states_dtypes = self.state_dtype if dtype is None else dtype + except NotImplementedError: + states_dtypes = framework.get_default_dtype() + if len(flatten(states_dtypes)) == 1: + dtype = flatten(states_dtypes)[0] + states_dtypes = map_structure(lambda shape: dtype, states_shapes) + + init_states = map_structure( + lambda shape, dtype: paddle.fluid.layers.fill_constant_batch_size_like( + input=batch_ref, + shape=shape.shape, + dtype=dtype, + value=init_value, + input_dim_idx=batch_dim_idx), states_shapes, states_dtypes) + return init_states + + @property + def state_shape(self): + r""" + Abstract method (property). + Used to initialize states. + A (possiblely nested structure of) shape[s], where a shape is a + list/tuple of integers (-1 for batch size would be automatically + inserted into a shape if shape is not started with it). + Not necessary to be implemented if states are not initialized by + `get_initial_states` or the `shape` argument is provided when using + `get_initial_states`. + """ + raise NotImplementedError( + "Please add implementaion for `state_shape` in the used cell.") + + @property + def state_dtype(self): + r""" + Abstract method (property). + Used to initialize states. + A (possiblely nested structure of) data types[s]. The structure must be + same as that of `shape`, except when all tensors' in states has the same + data type, a signle data type can be used. + Not necessary to be implemented if states are not initialized + by `get_initial_states` or the `dtype` argument is provided when using + `get_initial_states`. + """ + raise NotImplementedError( + "Please add implementaion for `state_dtype` in the used cell.") + + +class SimpleRNNCell(RNNCellBase): + r""" + Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it + computes the outputs and updates states. + + The formula used is as follows: + + .. math:: + h_{t} & = \mathrm{tanh}(W_{ih}x_{t} + b_{ih} + W_{hh}h{t-1} + b_{hh}) + y_{t} & = h_{t} + + where :math:`\sigma` is the sigmoid fucntion, and \* is the elemetwise + multiplication operator. + + Please refer to `Finding Structure in Time + `_ for more details. + + Arguments: + input_size (int): The input size. + hidden_size (int): The hidden size. + activation (str, optional): The activation in the SimpleRNN cell. + It can be `tanh` or `relu`. Defaults to `tanh`. + weight_ih_attr (ParamAttr, optional): The parameter attribute for + `weight_ih`. Default: None. 
+        weight_hh_attr (ParamAttr, optional): The parameter attribute for
+            `weight_hh`. Default: None.
+        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_ih`. Default: None.
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_hh`. Default: None.
+        name (str, optional): Name for the operation (optional, default is
+            None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Parameters:
+        weight_ih (Parameter): shape (hidden_size, input_size), input to hidden
+            weight, corresponding to :math:`W_{ih}` in the formula.
+        weight_hh (Parameter): shape (hidden_size, hidden_size), hidden to
+            hidden weight, corresponding to :math:`W_{hh}` in the formula.
+        bias_ih (Parameter): shape (hidden_size, ), input to hidden bias,
+            corresponding to :math:`b_{ih}` in the formula.
+        bias_hh (Parameter): shape (hidden_size, ), hidden to hidden bias,
+            corresponding to :math:`b_{hh}` in the formula.
+
+    Inputs:
+        inputs (Tensor): shape `[batch_size, input_size]`, the input,
+            corresponding to :math:`x_t` in the formula.
+        states (Tensor, optional): shape `[batch_size, hidden_size]`, the
+            previous hidden state, corresponding to :math:`h_{t-1}` in the
+            formula. When states is None, zero state is used. Defaults to
+            None.
+
+    Returns:
+        (outputs, new_states)
+        outputs (Tensor): shape `[batch_size, hidden_size]`, the output,
+            corresponding to :math:`h_{t}` in the formula.
+        states (Tensor): shape `[batch_size, hidden_size]`, the new hidden
+            state, corresponding to :math:`h_{t}` in the formula.
+
+    Notes:
+        All the weights and bias are initialized with `Uniform(-std, std)` by
+        default, where std = :math:`\frac{1}{\sqrt{hidden\_size}}`. For more
+        information about parameter initialization, please refer to
+        :ref:`api_fluid_ParamAttr`.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            paddle.disable_static()
+
+            x = paddle.randn((4, 16))
+            prev_h = paddle.randn((4, 32))
+
+            cell = paddle.nn.SimpleRNNCell(16, 32)
+            y, h = cell(x, prev_h)
+
+    """
+
+    def __init__(self,
+                 input_size,
+                 hidden_size,
+                 activation="tanh",
+                 weight_ih_attr=None,
+                 weight_hh_attr=None,
+                 bias_ih_attr=None,
+                 bias_hh_attr=None,
+                 name=None):
+        super(SimpleRNNCell, self).__init__()
+        std = 1.0 / math.sqrt(hidden_size)
+        self.weight_ih = self.create_parameter(
+            (hidden_size, input_size),
+            weight_ih_attr,
+            default_initializer=I.Uniform(-std, std))
+        self.weight_hh = self.create_parameter(
+            (hidden_size, hidden_size),
+            weight_hh_attr,
+            default_initializer=I.Uniform(-std, std))
+        self.bias_ih = self.create_parameter(
+            (hidden_size, ),
+            bias_ih_attr,
+            is_bias=True,
+            default_initializer=I.Uniform(-std, std))
+        self.bias_hh = self.create_parameter(
+            (hidden_size, ),
+            bias_hh_attr,
+            is_bias=True,
+            default_initializer=I.Uniform(-std, std))
+
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        if activation not in ["tanh", "relu"]:
+            raise ValueError(
+                "activation for SimpleRNNCell should be tanh or relu, "
+                "but got {}".format(activation))
+        self.activation = activation
+        self._activation_fn = paddle.tanh \
+            if activation == "tanh" \
+            else F.relu
+
+    def forward(self, inputs, states=None):
+        if states is None:
+            states = self.get_initial_states(inputs, self.state_shape)
+        pre_h = states
+        i2h = paddle.matmul(inputs, self.weight_ih, transpose_y=True)
+        if self.bias_ih is not None:
+            i2h += self.bias_ih
+        h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
+        if self.bias_hh is not None:
+            h2h += self.bias_hh
+        h = self._activation_fn(i2h + h2h)
+        return h, h
+
+    @property
+    def state_shape(self):
+        return (self.hidden_size, )
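+# Editorial note (illustrative, not part of the original patch): a cell only
+# computes one time step; the RNN wrapper defined later in this module
+# automates the loop below. A sketch of manually unrolling SimpleRNNCell over
+# a [batch_size, time_steps, input_size] input:
+#
+#     cell = SimpleRNNCell(16, 32)
+#     x = paddle.randn((4, 23, 16))
+#     h, outputs = None, []
+#     for t in range(x.shape[1]):
+#         y, h = cell(x[:, t, :], h)
+#         outputs.append(y)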
+
+
+class LSTMCell(RNNCellBase):
+    r"""
+    Long-Short Term Memory (LSTM) RNN cell. Given the inputs and previous
+    states, it computes the outputs and updates states.
+
+    The formula used is as follows:
+
+    .. math::
+        i_{t} & = \sigma(W_{ii}x_{t} + b_{ii} + W_{hi}h_{t-1} + b_{hi})
+        f_{t} & = \sigma(W_{if}x_{t} + b_{if} + W_{hf}h_{t-1} + b_{hf})
+        o_{t} & = \sigma(W_{io}x_{t} + b_{io} + W_{ho}h_{t-1} + b_{ho})
+        \widetilde{c}_{t} & = \tanh (W_{ig}x_{t} + b_{ig} + W_{hg}h_{t-1} + b_{hg})
+        c_{t} & = f_{t} \* c_{t-1} + i_{t} \* \widetilde{c}_{t}
+        h_{t} & = o_{t} \* \tanh(c_{t})
+        y_{t} & = h_{t}
+
+    where :math:`\sigma` is the sigmoid function, and \* is the elementwise
+    multiplication operator.
+
+    Please refer to `An Empirical Exploration of Recurrent Network Architectures
+    `_ for more details.
+
+    Arguments:
+        input_size (int): The input size.
+        hidden_size (int): The hidden size.
+        weight_ih_attr (ParamAttr, optional): The parameter attribute for
+            `weight_ih`. Default: None.
+        weight_hh_attr (ParamAttr, optional): The parameter attribute for
+            `weight_hh`. Default: None.
+        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_ih`. Default: None.
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_hh`. Default: None.
+        name (str, optional): Name for the operation (optional, default is
+            None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Parameters:
+        weight_ih (Parameter): shape (4 * hidden_size, input_size), input to
+            hidden weight, which corresponds to the concatenation of
+            :math:`W_{ii}, W_{if}, W_{ig}, W_{io}` in the formula.
+        weight_hh (Parameter): shape (4 * hidden_size, hidden_size), hidden to
+            hidden weight, which corresponds to the concatenation of
+            :math:`W_{hi}, W_{hf}, W_{hg}, W_{ho}` in the formula.
+        bias_ih (Parameter): shape (4 * hidden_size, ), input to hidden bias,
+            which corresponds to the concatenation of
+            :math:`b_{ii}, b_{if}, b_{ig}, b_{io}` in the formula.
+        bias_hh (Parameter): shape (4 * hidden_size, ), hidden to hidden bias,
+            which corresponds to the concatenation of
+            :math:`b_{hi}, b_{hf}, b_{hg}, b_{ho}` in the formula.
+
+    Inputs:
+        inputs (Tensor): shape `[batch_size, input_size]`, the input,
+            corresponding to :math:`x_t` in the formula.
+        states (tuple, optional): a tuple of two tensors, each of shape
+            `[batch_size, hidden_size]`, the previous hidden state,
+            corresponding to :math:`h_{t-1}, c_{t-1}` in the formula.
+            When states is None, zero state is used. Defaults to None.
+
+    Returns:
+        (outputs, new_states)
+        outputs (Tensor): shape `[batch_size, hidden_size]`, the output,
+            corresponding to :math:`h_{t}` in the formula.
+        states (tuple): a tuple of two tensors, each of shape
+            `[batch_size, hidden_size]`, the new hidden states,
+            corresponding to :math:`h_{t}, c_{t}` in the formula.
+
+    Notes:
+        All the weights and bias are initialized with `Uniform(-std, std)` by
+        default, where std = :math:`\frac{1}{\sqrt{hidden\_size}}`. For more
+        information about parameter initialization, please refer to
+        :ref:`api_fluid_ParamAttr`.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            paddle.disable_static()
+
+            x = paddle.randn((4, 16))
+            prev_h = paddle.randn((4, 32))
+            prev_c = paddle.randn((4, 32))
+
+            cell = paddle.nn.LSTMCell(16, 32)
+            y, (h, c) = cell(x, (prev_h, prev_c))
+
+    """
+
+    def __init__(self,
+                 input_size,
+                 hidden_size,
+                 weight_ih_attr=None,
+                 weight_hh_attr=None,
+                 bias_ih_attr=None,
+                 bias_hh_attr=None,
+                 name=None):
+        super(LSTMCell, self).__init__()
+        std = 1.0 / math.sqrt(hidden_size)
+        self.weight_ih = self.create_parameter(
+            (4 * hidden_size, input_size),
+            weight_ih_attr,
+            default_initializer=I.Uniform(-std, std))
+        self.weight_hh = self.create_parameter(
+            (4 * hidden_size, hidden_size),
+            weight_hh_attr,
+            default_initializer=I.Uniform(-std, std))
+        self.bias_ih = self.create_parameter(
+            (4 * hidden_size, ),
+            bias_ih_attr,
+            is_bias=True,
+            default_initializer=I.Uniform(-std, std))
+        self.bias_hh = self.create_parameter(
+            (4 * hidden_size, ),
+            bias_hh_attr,
+            is_bias=True,
+            default_initializer=I.Uniform(-std, std))
+
+        self.hidden_size = hidden_size
+        self.input_size = input_size
+        self._gate_activation = F.sigmoid
+        self._activation = paddle.tanh
+
+    def forward(self, inputs, states=None):
+        if states is None:
+            states = self.get_initial_states(inputs, self.state_shape)
+        pre_hidden, pre_cell = states
+        gates = paddle.matmul(inputs, self.weight_ih, transpose_y=True)
+        if self.bias_ih is not None:
+            gates = gates + self.bias_ih
+        gates += paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)
+        if self.bias_hh is not None:
+            gates = gates + self.bias_hh
+
+        # gates are laid out as the concatenation [i, f, g, o] along the last axis
+        chunked_gates = paddle.split(gates, num_or_sections=4, axis=-1)
+
+        i = self._gate_activation(chunked_gates[0])
+        f = self._gate_activation(chunked_gates[1])
+        o = self._gate_activation(chunked_gates[3])
+        c = f * pre_cell + i * self._activation(chunked_gates[2])
+        h = o * self._activation(c)
+
+        return h, (h, c)
+
+    @property
+    def state_shape(self):
+        r"""
+        The `state_shape` of LSTMCell is a tuple with two shapes:
+        `((hidden_size, ), (hidden_size, ))` (-1 for batch size would be
+        automatically inserted into shape). These two shapes correspond
+        to :math:`h_{t-1}` and :math:`c_{t-1}` respectively.
+        """
+        return ((self.hidden_size, ), (self.hidden_size, ))
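+# Editorial note (illustrative, not part of the original patch): LSTMCell
+# projects the input and the hidden state once each and splits the summed
+# result into four chunks, so weight_ih/weight_hh are laid out as the
+# concatenation [i, f, g, o] along dim 0. A sketch of the split performed in
+# forward, with assumed batch size 4 and hidden size 32:
+#
+#     gates = paddle.randn((4, 4 * 32))               # [batch, 4 * hidden]
+#     i, f, g, o = paddle.split(gates, num_or_sections=4, axis=-1)
+#     # i, f and o pass through sigmoid; g (the cell candidate) through tanh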
+
+
+class GRUCell(RNNCellBase):
+    r"""
+    Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
+    it computes the outputs and updates states.
+
+    The formula for GRU used is as follows:
+
+    .. math::
+
+        r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
+        z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
+        \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} \* (W_{hc}h_{t-1} + b_{hc}))
+        h_{t} & = z_{t} \* h_{t-1} + (1 - z_{t}) \* \widetilde{h}_{t}
+        y_{t} & = h_{t}
+
+    where :math:`\sigma` is the sigmoid function, and \* is the elementwise
+    multiplication operator.
+
+    Please refer to `An Empirical Exploration of Recurrent Network Architectures
+    `_ for more details.
+
+    Arguments:
+        input_size (int): The input size.
+        hidden_size (int): The hidden size.
+        weight_ih_attr (ParamAttr, optional): The parameter attribute for
+            `weight_ih`. Default: None.
+        weight_hh_attr (ParamAttr, optional): The parameter attribute for
+            `weight_hh`. Default: None.
+        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_ih`. Default: None.
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_hh`. Default: None.
+        name (str, optional): Name for the operation (optional, default is
+            None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Parameters:
+        weight_ih (Parameter): shape (3 * hidden_size, input_size), input to
+            hidden weight, which corresponds to the concatenation of
+            :math:`W_{ir}, W_{iz}, W_{ic}` in the formula.
+        weight_hh (Parameter): shape (3 * hidden_size, hidden_size), hidden to
+            hidden weight, which corresponds to the concatenation of
+            :math:`W_{hr}, W_{hz}, W_{hc}` in the formula.
+        bias_ih (Parameter): shape (3 * hidden_size, ), input to hidden bias,
+            which corresponds to the concatenation of
+            :math:`b_{ir}, b_{iz}, b_{ic}` in the formula.
+        bias_hh (Parameter): shape (3 * hidden_size, ), hidden to hidden bias,
+            which corresponds to the concatenation of
+            :math:`b_{hr}, b_{hz}, b_{hc}` in the formula.
+
+    Inputs:
+        inputs (Tensor): A tensor with shape `[batch_size, input_size]`,
+            corresponding to :math:`x_t` in the formula.
+        states (Tensor, optional): A tensor with shape `[batch_size, hidden_size]`,
+            corresponding to :math:`h_{t-1}` in the formula. When states is
+            None, zero state is used.
+
+    Returns:
+        (outputs, new_states)
+        outputs (Tensor): shape `[batch_size, hidden_size]`, the output,
+            corresponding to :math:`h_{t}` in the formula.
+        states (Tensor): shape `[batch_size, hidden_size]`, the new hidden
+            state, corresponding to :math:`h_{t}` in the formula.
+
+    Notes:
+        All the weights and bias are initialized with `Uniform(-std, std)` by
+        default, where std = :math:`\frac{1}{\sqrt{hidden\_size}}`. For more
+        information about parameter initialization, please refer to
+        :ref:`api_fluid_ParamAttr`.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            paddle.disable_static()
+
+            x = paddle.randn((4, 16))
+            prev_h = paddle.randn((4, 32))
+
+            cell = paddle.nn.GRUCell(16, 32)
+            y, h = cell(x, prev_h)
+
+    """
+
+    def __init__(self,
+                 input_size,
+                 hidden_size,
+                 weight_ih_attr=None,
+                 weight_hh_attr=None,
+                 bias_ih_attr=None,
+                 bias_hh_attr=None,
+                 name=None):
+        super(GRUCell, self).__init__()
+        std = 1.0 / math.sqrt(hidden_size)
+        self.weight_ih = self.create_parameter(
+            (3 * hidden_size, input_size),
+            weight_ih_attr,
+            default_initializer=I.Uniform(-std, std))
+        self.weight_hh = self.create_parameter(
+            (3 * hidden_size, hidden_size),
+            weight_hh_attr,
+            default_initializer=I.Uniform(-std, std))
+        self.bias_ih = self.create_parameter(
+            (3 * hidden_size, ),
+            bias_ih_attr,
+            is_bias=True,
+            default_initializer=I.Uniform(-std, std))
+        self.bias_hh = self.create_parameter(
+            (3 * hidden_size, ),
+            bias_hh_attr,
+            is_bias=True,
+            default_initializer=I.Uniform(-std, std))
+
+        self.hidden_size = hidden_size
+        self.input_size = input_size
+        self._gate_activation = F.sigmoid
+        self._activation = paddle.tanh
+
+    def forward(self, inputs, states=None):
+        if states is None:
+            states = self.get_initial_states(inputs, self.state_shape)
+
+        pre_hidden = states
+        x_gates = paddle.matmul(inputs, self.weight_ih, transpose_y=True)
+        if self.bias_ih is not None:
+            x_gates = x_gates + self.bias_ih
+        h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)
+        if self.bias_hh is not None:
+            h_gates = h_gates + self.bias_hh
+
+        x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)
+        h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)
+
+        r = self._gate_activation(x_r + h_r)
+        z = self._gate_activation(x_z + h_z)
+        c = self._activation(x_c + r * h_c)  # apply the reset gate after the matmul
+        h = (pre_hidden - c) * z + c
+
+        return h, h
+
+    @property
+    def state_shape(self):
+        r"""
+        The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch
+        size would be automatically inserted into shape). The shape corresponds
+        to the shape of :math:`h_{t-1}`.
+        """
+        return (self.hidden_size, )
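+# Editorial note (illustrative, not part of the original patch): GRUCell
+# applies the reset gate to the projected previous hidden state,
+# r * (W_hc·h + b_hc), rather than resetting h before the projection. This
+# matches the equations in the docstring above and lets all three hidden
+# projections (r, z, c) be computed in a single matmul.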
+
+
+class RNN(Layer):
+    r"""
+    Wrapper for RNN, which creates a recurrent neural network with an RNN cell.
+    It performs :code:`cell.forward()` repeatedly until it reaches the maximum
+    length of `inputs`.
+
+    Arguments:
+        cell(RNNCellBase): An instance of `RNNCellBase`.
+        is_reverse (bool, optional): Indicate whether to calculate in the reverse
+            order of input sequences. Defaults to False.
+        time_major (bool): Whether the first dimension of the input means the
+            time steps. Defaults to False.
+
+    Inputs:
+        inputs (Tensor): A (possibly nested structure of) tensor[s]. The input
+            sequences.
+            If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`,
+            else the shape is `[batch_size, time_steps, input_size]`,
+            where `input_size` is the input size of the cell.
+        initial_states (Tensor|list|tuple, optional): Tensor or a possibly
+            nested structure of tensors, representing the initial state for
+            the rnn cell. If not provided, `cell.get_initial_states` would be
+            called to produce the initial states. Defaults to None.
+        sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
+            or int32. The valid lengths of input sequences. Defaults to None.
+            If `sequence_length` is not None, the inputs are treated as
+            padded sequences. In each input sequence, elements whose time step
+            index are not less than the valid length are treated as paddings.
+        **kwargs: Additional keyword arguments to pass to `forward` of the cell.
+
+    Returns:
+        (outputs, final_states)
+        outputs (Tensor|list|tuple): the output sequences.
+            If `time_major` is True, the shape is
+            `[time_steps, batch_size, hidden_size]`, else
+            `[batch_size, time_steps, hidden_size]`.
+        final_states (Tensor|list|tuple): final states of the cell. Tensor or
+            a possibly nested structure of tensors which has the same structure
+            as the initial states. Each tensor in final states has the same
+            shape and dtype as the corresponding tensor in initial states.
+
+    Notes:
+        This class is a low level API for wrapping an rnn cell into an RNN
+        network. Users should take care of the state of the cell. If
+        `initial_states` is passed to the `forward` method, make sure that it
+        satisfies the requirements of the cell.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            paddle.disable_static()
+
+            inputs = paddle.rand((4, 23, 16))
+            prev_h = paddle.randn((4, 32))
+
+            cell = paddle.nn.SimpleRNNCell(16, 32)
+            rnn = paddle.nn.RNN(cell)
+            outputs, final_states = rnn(inputs, prev_h)
+
+    """
+
+    def __init__(self, cell, is_reverse=False, time_major=False):
+        super(RNN, self).__init__()
+        self.cell = cell
+        if not hasattr(self.cell, "call"):
+            # for non-dygraph mode, `rnn` api uses cell.call
+            self.cell.call = self.cell.forward
+        self.is_reverse = is_reverse
+        self.time_major = time_major
+
+    def forward(self,
+                inputs,
+                initial_states=None,
+                sequence_length=None,
+                **kwargs):
+        if initial_states is None:
+            initial_states = self.cell.get_initial_states(
+                batch_ref=inputs,
+                dtype=inputs.dtype,
+                batch_dim_idx=1 if self.time_major else 0)
+
+        final_outputs, final_states = F.rnn(self.cell,
+                                            inputs,
+                                            initial_states=initial_states,
+                                            sequence_length=sequence_length,
+                                            time_major=self.time_major,
+                                            is_reverse=self.is_reverse,
+                                            **kwargs)
+        return final_outputs, final_states
+
+
+class BiRNN(Layer):
+    r"""
+    Wrapper for bidirectional RNN, which builds a bidirectional RNN given the
+    forward rnn cell and backward rnn cell. A BiRNN applies the forward RNN
+    and the backward RNN with the corresponding cells separately and
+    concatenates the outputs along the last axis.
+
+    Arguments:
+        cell_fw (RNNCellBase): A RNNCellBase instance used for forward RNN.
+        cell_bw (RNNCellBase): A RNNCellBase instance used for backward RNN.
+        time_major (bool): Whether the first dimension of the input means the
+            time steps. Defaults to False.
+
+    Inputs:
+        inputs (Tensor): the input sequences of both RNN.
+            If time_major is True, the shape is
+            `[time_steps, batch_size, input_size]`, else the shape is
+            `[batch_size, time_steps, input_size]`, where input_size is the
+            input size of both cells.
+        initial_states (list|tuple, optional): A tuple/list of the initial
+            states of the forward cell and backward cell. If not provided,
+            `cell.get_initial_states` would be called to produce the initial
+            states for each cell. Defaults to None.
+        sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
+            or int32. The valid lengths of input sequences. Defaults to None.
+            If `sequence_length` is not None, the inputs are treated as
+            padded sequences. In each input sequence, elements whose time step
+            index are not less than the valid length are treated as paddings.
+        **kwargs: Additional keyword arguments. Arguments passed to `forward`
+            for each cell.
+
+    Outputs:
+        (outputs, final_states)
+        outputs (Tensor): the outputs of the bidirectional RNN. It is the
+            concatenation of the outputs from the forward RNN and backward
+            RNN along the last axis.
+            If time major is True, the shape is `[time_steps, batch_size, size]`,
+            else the shape is `[batch_size, time_steps, size]`, where size is
+            `cell_fw.hidden_size + cell_bw.hidden_size`.
+        final_states (tuple): A tuple of the final states of the forward
+            cell and backward cell.
+
+    Notes:
+        This class is a low level API for wrapping rnn cells into a BiRNN
+        network. Users should take care of the states of the cells.
+        If `initial_states` is passed to the `forward` method, make sure that
+        it satisfies the requirements of the cells.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            paddle.disable_static()
+
+            cell_fw = paddle.nn.LSTMCell(16, 32)
+            cell_bw = paddle.nn.LSTMCell(16, 32)
+            rnn = paddle.nn.BiRNN(cell_fw, cell_bw)
+
+            inputs = paddle.rand((2, 23, 16))
+            outputs, final_states = rnn(inputs)
+
+    """
+
+    def __init__(self, cell_fw, cell_bw, time_major=False):
+        super(BiRNN, self).__init__()
+        self.cell_fw = cell_fw
+        self.cell_bw = cell_bw
+        if cell_fw.input_size != cell_bw.input_size:
+            raise ValueError("input size of forward cell({}) does not equal "
+                             "that of backward cell({})".format(
+                                 cell_fw.input_size, cell_bw.input_size))
+        for cell in [self.cell_fw, self.cell_bw]:
+            if not hasattr(cell, "call"):
+                # for non-dygraph mode, `rnn` api uses cell.call
+                cell.call = cell.forward
+        self.time_major = time_major
+
+    def forward(self,
+                inputs,
+                initial_states=None,
+                sequence_length=None,
+                **kwargs):
+        if isinstance(initial_states, (list, tuple)):
+            assert len(initial_states) == 2, \
+                "length of initial_states should be 2 when it is a list/tuple"
+        else:
+            initial_states = [initial_states, initial_states]
+
+        outputs, final_states = F.birnn(self.cell_fw, self.cell_bw, inputs,
+                                        initial_states, sequence_length,
+                                        self.time_major, **kwargs)
+        return outputs, final_states
+
+
+class RNNMixin(LayerList):
+    r"""
+    A Mixin class for RNN networks. It provides the `forward` method for
+    SimpleRNN, LSTM and GRU.
+    """
+
+    def forward(self, inputs, initial_states=None, sequence_length=None):
+        batch_index = 1 if self.time_major else 0
+        dtype = inputs.dtype
+        if initial_states is None:
+            state_shape = (self.num_layers * self.num_directions, -1,
+                           self.hidden_size)
+            if self.state_components == 1:
+                initial_states = paddle.fluid.layers.fill_constant_batch_size_like(
+                    inputs, state_shape, dtype, 0, batch_index, 1)
+            else:
+                initial_states = tuple([
+                    paddle.fluid.layers.fill_constant_batch_size_like(
+                        inputs, state_shape, dtype, 0, batch_index, 1)
+                    for _ in range(self.state_components)
+                ])
+
+        states = split_states(initial_states, self.num_directions == 2,
+                              self.state_components)
+        final_states = []
+
+        for i, rnn_layer in enumerate(self):
+            if i > 0:
+                inputs = F.dropout(
+                    inputs,
+                    self.dropout,
+                    training=self.training,
+                    mode="upscale_in_train")
+            outputs, final_state = rnn_layer(inputs, states[i], sequence_length)
+            final_states.append(final_state)
+            inputs = outputs
+
+        final_states = concat_states(final_states, self.num_directions == 2,
+                                     self.state_components)
+        return outputs, final_states
+
+
+class SimpleRNN(RNNMixin):
+    r"""
+    Multilayer Elman network (SimpleRNN). It takes input sequences and initial
+    states as inputs, and returns the output sequences and the final states.
+
+    Each layer inside the SimpleRNN maps the input sequences and initial states
+    to the output sequences and final states in the following manner: at each
+    step, it takes step inputs (:math:`x_{t}`) and previous
+    states (:math:`h_{t-1}`) as inputs, and returns step outputs (:math:`y_{t}`)
+    and new states (:math:`h_{t}`).
+
+    .. math::
+
+        h_{t} & = \mathrm{tanh}(W_{ih}x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
+        y_{t} & = h_{t}
+
+    where :math:`\mathrm{tanh}` is the hyperbolic tangent (ReLU is used
+    instead when `activation` is `relu`).
+
+    Arguments:
+        input_size (int): The input size for the first layer's cell.
+        hidden_size (int): The hidden size for each layer's cell.
+        num_layers (int, optional): Number of layers. Defaults to 1.
+        activation (str, optional): The activation in each SimpleRNN cell. It can be
+            `tanh` or `relu`. Defaults to `tanh`.
+        direction (str, optional): The direction of the network. It can be "forward",
+            "backward" and "bidirectional". Defaults to "forward".
+        dropout (float, optional): The dropout probability. Dropout is applied to the
+            input of each layer except for the first layer. Defaults to 0.
+        time_major (bool, optional): Whether the first dimension of the input means the
+            time steps. Defaults to False.
+        weight_ih_attr (ParamAttr, optional): The parameter attribute for
+            `weight_ih` of each cell. Defaults to None.
+        weight_hh_attr (ParamAttr, optional): The parameter attribute for
+            `weight_hh` of each cell. Defaults to None.
+        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_ih` of each cell. Defaults to None.
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_hh` of each cell. Defaults to None.
+        name (str, optional): Name for the operation (optional, default is
+            None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Inputs:
+        inputs (Tensor): the input sequence.
+            If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`,
+            else, the shape is `[batch_size, time_steps, input_size]`.
+        initial_states (Tensor, optional): the initial state. The shape is
+            `[num_layers * num_directions, batch_size, hidden_size]`.
+            If initial_states is not given, zero initial states are used.
+        sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
+            or int32. The valid lengths of input sequences. Defaults to None.
+            If `sequence_length` is not None, the inputs are treated as
+            padded sequences. In each input sequence, elements whose time step
+            index are not less than the valid length are treated as paddings.
+
+    Returns:
+        (outputs, final_states)
+        outputs (Tensor): the output sequence.
+            If `time_major` is True, the shape is
+            `[time_steps, batch_size, num_directions * hidden_size]`,
+            else, the shape is
+            `[batch_size, time_steps, num_directions * hidden_size]`.
+            Note that `num_directions` is 2 if direction is "bidirectional"
+            else 1.
+        final_states (Tensor): final states. The shape is
+            `[num_layers * num_directions, batch_size, hidden_size]`.
+            Note that `num_directions` is 2 if direction is "bidirectional"
+            else 1.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            paddle.disable_static()
+
+            rnn = paddle.nn.SimpleRNN(16, 32, 2)
+
+            x = paddle.randn((4, 23, 16))
+            prev_h = paddle.randn((2, 4, 32))
+            y, h = rnn(x, prev_h)
+
+    """
+
+    def __init__(self,
+                 input_size,
+                 hidden_size,
+                 num_layers=1,
+                 activation="tanh",
+                 direction="forward",
+                 dropout=0.,
+                 time_major=False,
+                 weight_ih_attr=None,
+                 weight_hh_attr=None,
+                 bias_ih_attr=None,
+                 bias_hh_attr=None,
+                 name=None):
+        super(SimpleRNN, self).__init__()
+
+        if direction in ["forward", "backward"]:
+            is_reverse = direction == "backward"
+            cell = SimpleRNNCell(input_size, hidden_size, activation,
+                                 weight_ih_attr, weight_hh_attr, bias_ih_attr,
+                                 bias_hh_attr)
+            self.append(RNN(cell, is_reverse, time_major))
+            for i in range(1, num_layers):
+                cell = SimpleRNNCell(hidden_size, hidden_size, activation,
+                                     weight_ih_attr, weight_hh_attr,
+                                     bias_ih_attr, bias_hh_attr)
+                self.append(RNN(cell, is_reverse, time_major))
+        elif direction == "bidirectional":
+            cell_fw = SimpleRNNCell(input_size, hidden_size, activation,
+                                    weight_ih_attr, weight_hh_attr,
+                                    bias_ih_attr, bias_hh_attr)
+            cell_bw = SimpleRNNCell(input_size, hidden_size, activation,
+                                    weight_ih_attr, weight_hh_attr,
+                                    bias_ih_attr, bias_hh_attr)
+            self.append(BiRNN(cell_fw, cell_bw, time_major))
+            for i in range(1, num_layers):
+                cell_fw = SimpleRNNCell(
+                    2 * hidden_size, hidden_size, activation, weight_ih_attr,
+                    weight_hh_attr, bias_ih_attr, bias_hh_attr)
+                cell_bw = SimpleRNNCell(
+                    2 * hidden_size, hidden_size, activation, weight_ih_attr,
+                    weight_hh_attr, bias_ih_attr, bias_hh_attr)
+                self.append(BiRNN(cell_fw, cell_bw, time_major))
+        else:
+            raise ValueError(
+                "direction should be forward, backward or bidirectional, "
+                "received direction = {}".format(direction))
+
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.num_directions = 2 if direction == "bidirectional" else 1
+        self.time_major = time_major
+        self.num_layers = num_layers
+        self.state_components = 1
+
+
+class LSTM(RNNMixin):
+    r"""
+    Multilayer LSTM. It takes a sequence and an initial state as inputs, and
+    returns the output sequences and the final states.
+
+    Each layer inside the LSTM maps the input sequences and initial states
+    to the output sequences and final states in the following manner: at each
+    step, it takes step inputs (:math:`x_{t}`) and previous
+    states (:math:`h_{t-1}, c_{t-1}`) as inputs, and returns step
+    outputs (:math:`y_{t}`) and new states (:math:`h_{t}, c_{t}`).
+
+    .. math::
+
+        i_{t} & = \sigma(W_{ii}x_{t} + b_{ii} + W_{hi}h_{t-1} + b_{hi})
+        f_{t} & = \sigma(W_{if}x_{t} + b_{if} + W_{hf}h_{t-1} + b_{hf})
+        o_{t} & = \sigma(W_{io}x_{t} + b_{io} + W_{ho}h_{t-1} + b_{ho})
+        \widetilde{c}_{t} & = \tanh (W_{ig}x_{t} + b_{ig} + W_{hg}h_{t-1} + b_{hg})
+        c_{t} & = f_{t} \* c_{t-1} + i_{t} \* \widetilde{c}_{t}
+        h_{t} & = o_{t} \* \tanh(c_{t})
+        y_{t} & = h_{t}
+
+    where :math:`\sigma` is the sigmoid function, and \* is the elementwise
+    multiplication operator.
+
+    Arguments:
+        input_size (int): The input size for the first layer's cell.
+        hidden_size (int): The hidden size for each layer's cell.
+        num_layers (int, optional): Number of layers. Defaults to 1.
+        direction (str, optional): The direction of the network. It can be
+            "forward", "backward" and "bidirectional". Defaults to "forward".
+        dropout (float, optional): The dropout probability. Dropout is applied
+            to the input of each layer except for the first layer. Defaults to 0.
+        time_major (bool, optional): Whether the first dimension of the input
+            means the time steps. Defaults to False.
+        weight_ih_attr (ParamAttr, optional): The parameter attribute for
+            `weight_ih` of each cell. Default: None.
+        weight_hh_attr (ParamAttr, optional): The parameter attribute for
+            `weight_hh` of each cell. Default: None.
+        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_ih` of each cell. Default: None.
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_hh` of each cell. Default: None.
+        name (str, optional): Name for the operation (optional, default is
+            None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Inputs:
+        inputs (Tensor): the input sequence.
+            If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`,
+            else, the shape is `[batch_size, time_steps, input_size]`.
+        initial_states (tuple, optional): the initial state, a tuple of (h, c),
+            the shape of each is `[num_layers * num_directions, batch_size, hidden_size]`.
+            If initial_states is not given, zero initial states are used.
+        sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
+            or int32. The valid lengths of input sequences. Defaults to None.
+            If `sequence_length` is not None, the inputs are treated as
+            padded sequences. In each input sequence, elements whose time step
+            index are not less than the valid length are treated as paddings.
+
+    Returns:
+        (outputs, final_states)
+        outputs (Tensor): the output sequence.
+            If `time_major` is True, the shape is
+            `[time_steps, batch_size, num_directions * hidden_size]`,
+            If `time_major` is False, the shape is
+            `[batch_size, time_steps, num_directions * hidden_size]`.
+            Note that `num_directions` is 2 if direction is "bidirectional"
+            else 1.
+        final_states (tuple): the final state, a tuple of two tensors, h and c.
+            The shape of each is
+            `[num_layers * num_directions, batch_size, hidden_size]`.
+            Note that `num_directions` is 2 if direction is "bidirectional"
+            else 1.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            paddle.disable_static()
+
+            rnn = paddle.nn.LSTM(16, 32, 2)
+
+            x = paddle.randn((4, 23, 16))
+            prev_h = paddle.randn((2, 4, 32))
+            prev_c = paddle.randn((2, 4, 32))
+            y, (h, c) = rnn(x, (prev_h, prev_c))
+
+    """
+
+    def __init__(self,
+                 input_size,
+                 hidden_size,
+                 num_layers=1,
+                 direction="forward",
+                 dropout=0.,
+                 time_major=False,
+                 weight_ih_attr=None,
+                 weight_hh_attr=None,
+                 bias_ih_attr=None,
+                 bias_hh_attr=None,
+                 name=None):
+        super(LSTM, self).__init__()
+
+        if direction in ["forward", "backward"]:
+            is_reverse = direction == "backward"
+            cell = LSTMCell(input_size, hidden_size, weight_ih_attr,
+                            weight_hh_attr, bias_ih_attr, bias_hh_attr)
+            self.append(RNN(cell, is_reverse, time_major))
+            for i in range(1, num_layers):
+                cell = LSTMCell(hidden_size, hidden_size, weight_ih_attr,
+                                weight_hh_attr, bias_ih_attr, bias_hh_attr)
+                self.append(RNN(cell, is_reverse, time_major))
+        elif direction == "bidirectional":
+            cell_fw = LSTMCell(input_size, hidden_size, weight_ih_attr,
+                               weight_hh_attr, bias_ih_attr, bias_hh_attr)
+            cell_bw = LSTMCell(input_size, hidden_size, weight_ih_attr,
+                               weight_hh_attr, bias_ih_attr, bias_hh_attr)
+            self.append(BiRNN(cell_fw, cell_bw, time_major))
+            for i in range(1, num_layers):
+                cell_fw = LSTMCell(2 * hidden_size, hidden_size, weight_ih_attr,
+                                   weight_hh_attr, bias_ih_attr, bias_hh_attr)
+                cell_bw = LSTMCell(2 * hidden_size, hidden_size, weight_ih_attr,
+                                   weight_hh_attr, bias_ih_attr, bias_hh_attr)
+                self.append(BiRNN(cell_fw, cell_bw, time_major))
+        else:
+            raise ValueError(
+                "direction should be forward, backward or bidirectional, "
+                "received direction = {}".format(direction))
+
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.num_directions = 2 if direction == "bidirectional" else 1
+        self.time_major = time_major
+        self.num_layers = num_layers
+        self.state_components = 2
+
+
+class GRU(RNNMixin):
+    r"""
+    Multilayer GRU. It takes input sequences and initial states as inputs, and
+    returns the output sequences and the final states.
+
+    Each layer inside the GRU maps the input sequences and initial states
+    to the output sequences and final states in the following manner: at each
+    step, it takes step inputs (:math:`x_{t}`) and previous
+    states (:math:`h_{t-1}`) as inputs, and returns step outputs (:math:`y_{t}`)
+    and new states (:math:`h_{t}`).
+
+    .. math::
+
+        r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
+        z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
+        \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} \* (W_{hc}h_{t-1} + b_{hc}))
+        h_{t} & = z_{t} \* h_{t-1} + (1 - z_{t}) \* \widetilde{h}_{t}
+        y_{t} & = h_{t}
+
+    where :math:`\sigma` is the sigmoid function, and \* is the elementwise
+    multiplication operator.
+
+    Arguments:
+        input_size (int): The input size for the first layer's cell.
+        hidden_size (int): The hidden size for each layer's cell.
+        num_layers (int, optional): Number of layers. Defaults to 1.
+        direction (str, optional): The direction of the network. It can be
+            "forward", "backward" and "bidirectional". Defaults to "forward".
+        dropout (float, optional): The dropout probability. Dropout is applied
+            to the input of each layer except for the first layer. Defaults to 0.
+        time_major (bool, optional): Whether the first dimension of the input
+            means the time steps. Defaults to False.
+        weight_ih_attr (ParamAttr, optional): The parameter attribute for
+            `weight_ih` of each cell. Default: None.
+        weight_hh_attr (ParamAttr, optional): The parameter attribute for
+            `weight_hh` of each cell. Default: None.
+        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_ih` of each cell. Default: None.
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_hh` of each cell. Default: None.
+        name (str, optional): Name for the operation (optional, default is
+            None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Inputs:
+        inputs (Tensor): the input sequence.
+            If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`,
+            else, the shape is `[batch_size, time_steps, input_size]`.
+        initial_states (Tensor, optional): the initial state. The shape is
+            `[num_layers * num_directions, batch_size, hidden_size]`.
+            If initial_states is not given, zero initial states are used.
+            Defaults to None.
+        sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
+            or int32. The valid lengths of input sequences. Defaults to None.
+            If `sequence_length` is not None, the inputs are treated as
+            padded sequences. In each input sequence, elements whose time step
+            index are not less than the valid length are treated as paddings.
+
+    Returns:
+        (outputs, final_states)
+        outputs (Tensor): the output sequence.
+            If `time_major` is True, the shape is
+            `[time_steps, batch_size, num_directions * hidden_size]`,
+            else, the shape is
+            `[batch_size, time_steps, num_directions * hidden_size]`.
+            Note that `num_directions` is 2 if direction is "bidirectional"
+            else 1.
+        final_states (Tensor): final states. The shape is
+            `[num_layers * num_directions, batch_size, hidden_size]`.
+            Note that `num_directions` is 2 if direction is "bidirectional"
+            else 1.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            paddle.disable_static()
+
+            rnn = paddle.nn.GRU(16, 32, 2)
+
+            x = paddle.randn((4, 23, 16))
+            prev_h = paddle.randn((2, 4, 32))
+            y, h = rnn(x, prev_h)
+
+    """
+
+    def __init__(self,
+                 input_size,
+                 hidden_size,
+                 num_layers=1,
+                 direction="forward",
+                 dropout=0.,
+                 time_major=False,
+                 weight_ih_attr=None,
+                 weight_hh_attr=None,
+                 bias_ih_attr=None,
+                 bias_hh_attr=None,
+                 name=None):
+        super(GRU, self).__init__()
+
+        if direction in ["forward", "backward"]:
+            is_reverse = direction == "backward"
+            cell = GRUCell(input_size, hidden_size, weight_ih_attr,
+                           weight_hh_attr, bias_ih_attr, bias_hh_attr)
+            self.append(RNN(cell, is_reverse, time_major))
+            for i in range(1, num_layers):
+                cell = GRUCell(hidden_size, hidden_size, weight_ih_attr,
+                               weight_hh_attr, bias_ih_attr, bias_hh_attr)
+                self.append(RNN(cell, is_reverse, time_major))
+        elif direction == "bidirectional":
+            cell_fw = GRUCell(input_size, hidden_size, weight_ih_attr,
+                              weight_hh_attr, bias_ih_attr, bias_hh_attr)
+            cell_bw = GRUCell(input_size, hidden_size, weight_ih_attr,
+                              weight_hh_attr, bias_ih_attr, bias_hh_attr)
+            self.append(BiRNN(cell_fw, cell_bw, time_major))
+            for i in range(1, num_layers):
+                cell_fw = GRUCell(2 * hidden_size, hidden_size, weight_ih_attr,
+                                  weight_hh_attr, bias_ih_attr, bias_hh_attr)
+                cell_bw = GRUCell(2 * hidden_size, hidden_size, weight_ih_attr,
+                                  weight_hh_attr, bias_ih_attr, bias_hh_attr)
+                self.append(BiRNN(cell_fw, cell_bw, time_major))
+        else:
+            raise ValueError(
+                "direction should be forward, backward or bidirectional, "
+                "received direction = {}".format(direction))
+
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.num_directions = 2 if direction == "bidirectional" else 1
+
self.time_major = time_major + self.num_layers = num_layers + self.state_components = 1 diff --git a/tools/wlist.json b/tools/wlist.json index c6114918e5..ce6f5fb176 100644 --- a/tools/wlist.json +++ b/tools/wlist.json @@ -148,7 +148,20 @@ "Callback.on_eval_batch_end", "Callback.on_test_batch_begin", "Callback.on_test_batch_end", - "Model.prepare" + "Model.prepare", + "SimpleRNNCell", + "SimpleRNNCell.forward", + "LSTMCell", + "LSTMCell.forward", + "GRUCell", + "GRUCell.forward", + "SimpleRNN", + "GRU", + "LSTM", + "RNN", + "BiRNN", + "RNNCellBase", + "RNNCellBase.get_initial_states" ], "wlist_no_op_pass":[ "gelu", -- GitLab
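Editorial postscript (not part of the patch): a minimal end-to-end sketch of
the dygraph API surface added here, combining a cell, the RNN wrapper, and a
stacked bidirectional network. It assumes a paddle build that includes this
patch; shapes follow the docstrings above.

    import paddle
    paddle.disable_static()

    x = paddle.randn((4, 23, 16))  # [batch_size, time_steps, input_size]

    # a single cell wrapped into a sequence-level RNN
    cell = paddle.nn.LSTMCell(16, 32)
    rnn = paddle.nn.RNN(cell)
    outputs, (h, c) = rnn(x)       # outputs: [4, 23, 32]; h, c: [4, 32]

    # a 2-layer bidirectional network built from the same primitives
    lstm = paddle.nn.LSTM(16, 32, num_layers=2, direction="bidirectional")
    outputs, (h, c) = lstm(x)      # outputs: [4, 23, 64]; h, c: [4, 4, 32]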