Unverified commit f4083010, authored by Feiyu Chan, committed by GitHub

Add unified RNN APIs (#26588)

* Add RNN related APIs in paddle.nn
test=develop

* new rnn api, cell almost done

* add new progress in rnn APIs for 2.0

* refine rnn APIs and docstrings.

* add unittests

* disable gpu tests when paddle is not compiled with cuda support

* remove unnecessary imports

* fix docstring

* add to no_sample wlist

* backport to python2 to avoid yield from

* add **kwargs, fix typos

* update docstrings for birnn

* rename argument for SimpleRNN and SimpleRNNCell, fix sample code

* add default value for initial_states in fluid.layers.birnn
Co-authored-by: guosheng <guosheng@baidu.com>
Parent f311d3c1
@@ -38,6 +38,7 @@ __all__ = [
'Decoder',
'BeamSearchDecoder',
'rnn',
'birnn',
'dynamic_decode',
'DecodeHelper',
'TrainingHelper',
@@ -438,61 +439,146 @@ def rnn(cell,
is_reverse=False,
**kwargs):
"""
:api_attr: Static Graph
rnn creates a recurrent neural network specified by RNNCell `cell`,
which performs :code:`cell.call()` repeatedly until it reaches the maximum
length of `inputs`.
Parameters:
cell(RNNCell): An instance of `RNNCell`.
inputs(Variable): A (possibly nested structure of) tensor variable[s].
The shape of tensor should be `[batch_size, sequence_length, ...]`
for `time_major == False` or `[sequence_length, batch_size, ...]`
for `time_major == True`. It represents the inputs to be unrolled
in RNN.
initial_states(Variable, optional): A (possibly nested structure of)
tensor variable[s], representing the initial state for RNN.
If not provided, `cell.get_initial_states` would be used to produce
the initial state. Default None.
sequence_length(Variable, optional): A tensor with shape `[batch_size]`.
It stores the real length of each instance, which enables users to
extract the last valid state once a batch element's sequence length
has been exceeded. If not provided, the paddings would be treated the
same as non-padding inputs. Default None.
time_major(bool, optional): Indicate the data layout of Tensor included
in `input` and `output` tensors. If `False`, the data layout would
be batch major with shape `[batch_size, sequence_length, ...]`. If
`True`, the data layout would be time major with shape
`[sequence_length, batch_size, ...]`. Default: `False`.
is_reverse(bool, optional): Indicate whether to calculate in the reverse
order of input sequences. Default: `False`.
**kwargs: Additional keyword arguments. Arguments passed to `cell.call`.
which performs :code:`cell.call()` (for dygraph mode :code:`cell.forward`)
repeatedly until it reaches the maximum length of `inputs`.
Arguments:
cell(RNNCellBase): An instance of `RNNCellBase`.
inputs(Tensor): the input sequences.
If time_major is True, the shape is
`[time_steps, batch_size, input_size]`
else the shape is `[batch_size, time_steps, input_size]`.
initial_states(Tensor|tuple|list, optional): the initial state of the
rnn cell. Tensor or a possibly nested structure of tensors. If not
provided, `cell.get_initial_states` would be called to produce
the initial state. Defaults to None.
sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
or int32. The valid lengths of input sequences. Defaults to None.
If `sequence_length` is not None, the inputs are treated as
padded sequences. In each input sequence, elements whose time step
index is not less than the valid length are treated as paddings.
time_major (bool): Whether the first dimension of the input means the
time steps. Defaults to False.
is_reverse (bool, optional): Indicate whether to calculate in the reverse
order of input sequences. Defaults to False.
**kwargs: Additional keyword arguments to pass to `forward` of the cell.
Returns:
tuple: A tuple( :code:`(final_outputs, final_states)` ) including the final \
outputs and states, both are Tensor or nested structure of Tensor. \
`final_outputs` has the same structure and data types as \
the returned `outputs` of :code:`cell.call`, and each Tensor in `final_outputs` \
stacks all time steps' counterpart in `outputs` thus has shape `[batch_size, sequence_length, ...]` \
for `time_major == False` or `[sequence_length, batch_size, ...]` for `time_major == True`. \
`final_states` is the counterpart at last time step of initial states, \
thus has the same structure with it and has tensors with same shapes \
and data types.
(outputs, final_states)
outputs (Tensor|list|tuple): the output sequence. Tensor or nested
structure of Tensors.
If `time_major` is True, the shape of each tensor in outputs is
`[time_steps, batch_size, hidden_size]`, else
`[batch_size, time_steps, hidden_size]`.
final_states (Tensor|list|tuple): final states. A (possibly nested structure of)
tensor[s], representing the final state for RNN. It has the same
structure as the initial states. Each tensor in final states has the same
shape and dtype as the corresponding tensor in initial states.
Examples:
.. code-block:: python
import paddle.fluid as fluid
inputs = fluid.data(name="inputs",
shape=[-1, 32, 128],
dtype="float32")
cell = fluid.layers.GRUCell(hidden_size=128)
outputs = fluid.layers.rnn(cell=cell, inputs=inputs)
import paddle
paddle.disable_static()
cell = paddle.nn.SimpleRNNCell(16, 32)
inputs = paddle.rand((4, 23, 16))
prev_h = paddle.randn((4, 32))
outputs, final_states = paddle.nn.functional.rnn(cell, inputs, prev_h)
"""
if in_dygraph_mode():
return _rnn_dynamic_graph(cell, inputs, initial_states, sequence_length,
time_major, is_reverse, **kwargs)
else:
return _rnn_static_graph(cell, inputs, initial_states, sequence_length,
time_major, is_reverse, **kwargs)
class ArrayWrapper(object):
def __init__(self, x):
self.array = [x]
def append(self, x):
self.array.append(x)
return self
def _maybe_copy(state, new_state, step_mask):
"""update rnn state or just pass the old state through"""
new_state = nn.elementwise_mul(new_state, step_mask, axis=0) \
+ nn.elementwise_mul(state, (1 - step_mask), axis=0)
return new_state
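# Note on `_maybe_copy`: the two elementwise_mul calls above compute
# new_state * step_mask + state * (1 - step_mask), broadcasting the
# [batch_size] step mask over the state's trailing dimensions, so a
# sequence's state is frozen once its valid length has been passed.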
def _transpose_batch_time(x):
perm = [1, 0] + list(range(2, len(x.shape)))
return nn.transpose(x, perm)
def _rnn_dynamic_graph(cell,
inputs,
initial_states=None,
sequence_length=None,
time_major=False,
is_reverse=False,
**kwargs):
time_step_index = 0 if time_major else 1
flat_inputs = flatten(inputs)
time_steps = flat_inputs[0].shape[time_step_index]
if not time_major:
inputs = map_structure(_transpose_batch_time, inputs)
if sequence_length is not None:
mask = sequence_lod.sequence_mask(
sequence_length, maxlen=time_steps, dtype=flat_inputs[0].dtype)
mask = nn.transpose(mask, [1, 0])
if is_reverse:
inputs = map_structure(lambda x: tensor.reverse(x, axis=[0]), inputs)
mask = tensor.reverse(mask, axis=[0]) \
if sequence_length is not None else None
states = initial_states
outputs = []
for i in range(time_steps):
step_inputs = map_structure(lambda x: x[i], inputs)
step_outputs, new_states = cell(step_inputs, states, **kwargs)
if sequence_length is not None:
new_states = map_structure(
partial(
_maybe_copy, step_mask=mask[i]), states, new_states)
states = new_states
outputs = map_structure(lambda x: ArrayWrapper(x),
step_outputs) if i == 0 else map_structure(
lambda x, x_array: x_array.append(x),
step_outputs, outputs)
final_outputs = map_structure(
lambda x: nn.stack(x.array, axis=time_step_index),
outputs)
if is_reverse:
final_outputs = map_structure(
lambda x: tensor.reverse(x, axis=time_step_index),
final_outputs)
final_states = new_states
return final_outputs, final_states
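# Note: `final_states` is taken from the last computed `new_states`; when
# `sequence_length` is given, `_maybe_copy` has already frozen each
# sequence's state at its valid length, so this is also correct for
# padded batches.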
def _rnn_static_graph(cell,
inputs,
initial_states=None,
sequence_length=None,
time_major=False,
is_reverse=False,
**kwargs):
check_type(inputs, 'inputs', (Variable, list, tuple), 'rnn')
if isinstance(inputs, (list, tuple)):
for i, input_x in enumerate(inputs):
@@ -500,30 +586,10 @@ def rnn(cell,
['float32', 'float64'], 'rnn')
check_type(initial_states, 'initial_states',
(Variable, list, tuple, type(None)), 'rnn')
if isinstance(initial_states, (list, tuple)):
states = map_structure(lambda x: x, initial_states)[0]
for i, state in enumerate(states):
if isinstance(state, (list, tuple)):
for j, state_j in enumerate(state):
check_variable_and_dtype(state_j, 'state_j[' + str(j) + ']',
['float32', 'float64'], 'rnn')
else:
check_variable_and_dtype(state, 'states[' + str(i) + ']',
['float32', 'float64'], 'rnn')
check_type(sequence_length, 'sequence_length', (Variable, type(None)),
'rnn')
def _maybe_copy(state, new_state, step_mask):
# TODO: use where_op
new_state = nn.elementwise_mul(
new_state, step_mask, axis=0) - nn.elementwise_mul(
state, (step_mask - 1), axis=0)
return new_state
def _transpose_batch_time(x):
return nn.transpose(x, [1, 0] + list(range(2, len(x.shape))))
def _switch_grad(x, stop=False):
x.stop_gradient = stop
return x
@@ -582,6 +648,98 @@ def rnn(cell,
return (final_outputs, final_states)
def birnn(cell_fw,
cell_bw,
inputs,
initial_states=None,
sequence_length=None,
time_major=False,
**kwargs):
"""
birnn creates a bidirectional recurrent neural network specified by
RNNCell `cell_fw` and `cell_bw`, which performs :code:`cell.call()`
(for dygraph mode :code:`cell.forward`) repeatedly until it reaches
the maximum length of `inputs`, and then concatenates the outputs of
both RNNs along the last axis.
Arguments:
cell_fw(RNNCellBase): An instance of `RNNCellBase`.
cell_bw(RNNCellBase): An instance of `RNNCellBase`.
inputs(Tensor): the input sequences.
If time_major is True, the shape is
`[time_steps, batch_size, input_size]`
else the shape is `[batch_size, time_steps, input_size]`.
initial_states(tuple, optional): A tuple of initial states of
`cell_fw` and `cell_bw`.
If not provided, `cell.get_initial_states` would be called to
produce initial state for each cell. Defaults to None.
sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
or int32. The valid lengths of input sequences. Defaults to None.
If `sequence_length` is not None, the inputs are treated as
padded sequences. In each input sequence, elements whose time step
index is not less than the valid length are treated as paddings.
time_major (bool): Whether the first dimension of the input means the
time steps. Defaults to False.
**kwargs: Additional keyword arguments to pass to `forward` of each cell.
Returns:
(outputs, final_states)
outputs (Tensor): the outputs of the bidirectional RNN. It is the
concatenation of the outputs from the forward RNN and backward
RNN along the last axis.
If time_major is True, the shape is `[time_steps, batch_size, size]`,
else the shape is `[batch_size, time_steps, size]`, where size is
`cell_fw.hidden_size + cell_bw.hidden_size`.
final_states (tuple): A tuple of the final states of the forward
cell and backward cell.
Examples:
.. code-block:: python
import paddle
paddle.disable_static()
cell_fw = paddle.nn.LSTMCell(16, 32)
cell_bw = paddle.nn.LSTMCell(16, 32)
inputs = paddle.rand((4, 23, 16))
hf, cf = paddle.rand((4, 32)), paddle.rand((4, 32))
hb, cb = paddle.rand((4, 32)), paddle.rand((4, 32))
initial_states = ((hf, cf), (hb, cb))
outputs, final_states = paddle.nn.functional.birnn(
cell_fw, cell_bw, inputs, initial_states)
"""
if initial_states is None:
states_fw = cell_fw.get_initial_states(
batch_ref=inputs, batch_dim_idx=1 if time_major else 0)
states_bw = cell_bw.get_initial_states(
batch_ref=inputs, batch_dim_idx=1 if time_major else 0)
else:
states_fw, states_bw = initial_states
outputs_fw, states_fw = rnn(cell_fw,
inputs,
states_fw,
sequence_length,
time_major=time_major,
**kwargs)
outputs_bw, states_bw = rnn(cell_bw,
inputs,
states_bw,
sequence_length,
time_major=time_major,
is_reverse=True,
**kwargs)
outputs = map_structure(lambda x, y: tensor.concat([x, y], -1), outputs_fw,
outputs_bw)
final_states = (states_fw, states_bw)
return outputs, final_states
class Decoder(object):
"""
:api_attr: Static Graph
......
@@ -542,6 +542,7 @@ endif()
add_subdirectory(sequence)
add_subdirectory(dygraph_to_static)
add_subdirectory(rnn)
if (WITH_MKLDNN)
add_subdirectory(mkldnn)
......
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
def convert_params_for_cell(np_cell, paddle_cell):
state = np_cell.parameters
for k, v in paddle_cell.named_parameters():
v.set_value(state[k])
def convert_params_for_cell_static(np_cell, paddle_cell, place):
state = np_cell.parameters
for k, v in paddle_cell.named_parameters():
scope = paddle.static.global_scope()
tensor = scope.find_var(v.name).get_tensor()
tensor.set(state[k], place)
def convert_params_for_net(np_net, paddle_net):
for np_layer, paddle_layer in zip(np_net, paddle_net):
if hasattr(np_layer, "cell"):
convert_params_for_cell(np_layer.cell, paddle_layer.cell)
else:
convert_params_for_cell(np_layer.cell_fw, paddle_layer.cell_fw)
convert_params_for_cell(np_layer.cell_bw, paddle_layer.cell_bw)
def convert_params_for_net_static(np_net, paddle_net, place):
for np_layer, paddle_layer in zip(np_net, paddle_net):
if hasattr(np_layer, "cell"):
convert_params_for_cell_static(np_layer.cell, paddle_layer.cell,
place)
else:
convert_params_for_cell_static(np_layer.cell_fw,
paddle_layer.cell_fw, place)
convert_params_for_cell_static(np_layer.cell_bw,
paddle_layer.cell_bw, place)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import math
class LayerMixin(object):
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
class LayerListMixin(LayerMixin):
def __init__(self, layers=None):
self._layers = list(layers) if layers else []
def append(self, layer):
self._layers.append(layer)
def __iter__(self):
return iter(self._layers)
class SimpleRNNCell(LayerMixin):
def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"):
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
if nonlinearity == 'tanh':
self.nonlinearity = np.tanh
else:
self.nonlinearity = lambda x: np.maximum(x, 0.)
self.parameters = dict()
std = 1.0 / math.sqrt(hidden_size)
self.weight_ih = np.random.uniform(-std, std, (
hidden_size, input_size)).astype('float64')
self.weight_hh = np.random.uniform(-std, std, (
hidden_size, hidden_size)).astype('float64')
self.parameters['weight_ih'] = self.weight_ih
self.parameters['weight_hh'] = self.weight_hh
if bias:
self.bias_ih = np.random.uniform(-std, std,
(hidden_size, )).astype('float64')
self.bias_hh = np.random.uniform(-std, std,
(hidden_size, )).astype('float64')
self.parameters['bias_ih'] = self.bias_ih
self.parameters['bias_hh'] = self.bias_hh
else:
self.bias_ih = None
self.bias_hh = None
def init_state(self, inputs):
batch_size = inputs.shape[0]
return np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype)
def forward(self, inputs, hx=None):
if hx is None:
hx = self.init_state(inputs)
pre_h = hx
i2h = np.matmul(inputs, self.weight_ih.T)
if self.bias_ih is not None:
i2h += self.bias_ih
h2h = np.matmul(pre_h, self.weight_hh.T)
if self.bias_hh is not None:
h2h += self.bias_hh
h = self.nonlinearity(i2h + h2h)
return h, h
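# Reference update rule implemented by SimpleRNNCell.forward above:
#   h_t = act(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh)
# where act is tanh or relu, selected by `nonlinearity`.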
class GRUCell(LayerMixin):
def __init__(self, input_size, hidden_size, bias=True):
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
self.parameters = dict()
std = 1.0 / math.sqrt(hidden_size)
self.weight_ih = np.random.uniform(-std, std, (
3 * hidden_size, input_size)).astype('float64')
self.weight_hh = np.random.uniform(-std, std, (
3 * hidden_size, hidden_size)).astype('float64')
self.parameters['weight_ih'] = self.weight_ih
self.parameters['weight_hh'] = self.weight_hh
if bias:
self.bias_ih = np.random.uniform(-std, std, (
3 * hidden_size)).astype('float64')
self.bias_hh = np.random.uniform(-std, std, (
3 * hidden_size)).astype('float64')
self.parameters['bias_ih'] = self.bias_ih
self.parameters['bias_hh'] = self.bias_hh
else:
self.bias_ih = None
self.bias_hh = None
def init_state(self, inputs):
batch_size = inputs.shape[0]
return np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype)
def forward(self, inputs, hx=None):
if hx is None:
hx = self.init_state(inputs)
pre_hidden = hx
x_gates = np.matmul(inputs, self.weight_ih.T)
if self.bias_ih is not None:
x_gates = x_gates + self.bias_ih
h_gates = np.matmul(pre_hidden, self.weight_hh.T)
if self.bias_hh is not None:
h_gates = h_gates + self.bias_hh
x_r, x_z, x_c = np.split(x_gates, 3, 1)
h_r, h_z, h_c = np.split(h_gates, 3, 1)
r = 1.0 / (1.0 + np.exp(-(x_r + h_r)))
z = 1.0 / (1.0 + np.exp(-(x_z + h_z)))
c = np.tanh(x_c + r * h_c) # apply reset gate after mm
h = (pre_hidden - c) * z + c
return h, h
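# Gate equations implemented by GRUCell.forward above (the reset gate is
# applied after the hidden-to-hidden matmul, the "reset after matmul"
# variant noted in the code comment):
#   r_t = sigmoid(x_r + h_r)
#   z_t = sigmoid(x_z + h_z)
#   c_t = tanh(x_c + r_t * h_c)
#   h_t = z_t * h_{t-1} + (1 - z_t) * c_t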
class LSTMCell(LayerMixin):
def __init__(self, input_size, hidden_size, bias=True):
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
self.parameters = dict()
std = 1.0 / math.sqrt(hidden_size)
self.weight_ih = np.random.uniform(-std, std, (
4 * hidden_size, input_size)).astype('float64')
self.weight_hh = np.random.uniform(-std, std, (
4 * hidden_size, hidden_size)).astype('float64')
self.parameters['weight_ih'] = self.weight_ih
self.parameters['weight_hh'] = self.weight_hh
if bias:
self.bias_ih = np.random.uniform(-std, std, (
4 * hidden_size)).astype('float64')
self.bias_hh = np.random.uniform(-std, std, (
4 * hidden_size)).astype('float64')
self.parameters['bias_ih'] = self.bias_ih
self.parameters['bias_hh'] = self.bias_hh
else:
self.bias_ih = None
self.bias_hh = None
def init_state(self, inputs):
batch_size = inputs.shape[0]
init_h = np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype)
init_c = np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype)
return init_h, init_c
def forward(self, inputs, hx=None):
if hx is None:
hx = self.init_state(inputs)
pre_hidden, pre_cell = hx
gates = np.matmul(inputs, self.weight_ih.T)
if self.bias_ih is not None:
gates = gates + self.bias_ih
gates += np.matmul(pre_hidden, self.weight_hh.T)
if self.bias_hh is not None:
gates = gates + self.bias_hh
chunked_gates = np.split(gates, 4, -1)
i = 1.0 / (1.0 + np.exp(-chunked_gates[0]))
f = 1.0 / (1.0 + np.exp(-chunked_gates[1]))
o = 1.0 / (1.0 + np.exp(-chunked_gates[3]))
c = f * pre_cell + i * np.tanh(chunked_gates[2])
h = o * np.tanh(c)
return h, (h, c)
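# Gate layout of the fused weights above is [i, f, g, o] (chunks 0..3):
#   i_t = sigmoid(g_i),  f_t = sigmoid(g_f),  o_t = sigmoid(g_o)
#   c_t = f_t * c_{t-1} + i_t * tanh(g_g)
#   h_t = o_t * tanh(c_t)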
def sequence_mask(lengths, max_len=None):
if max_len is None:
max_len = np.max(lengths)
else:
assert max_len >= np.max(lengths)
return np.arange(max_len) < np.expand_dims(lengths, -1)
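# Example (illustrative values):
#   sequence_mask(np.array([1, 3]), max_len=3)
#   -> array([[ True, False, False],
#             [ True,  True,  True]])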
def update_state(mask, new, old):
if not isinstance(old, (tuple, list)):
return np.where(mask, new, old)
else:
return tuple(map(lambda x, y: np.where(mask, x, y), new, old))
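# `update_state` keeps the old state wherever the mask is False, so states
# stop updating once a sequence's valid length has been passed.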
def rnn(cell,
inputs,
initial_states,
sequence_length=None,
time_major=False,
is_reverse=False):
if not time_major:
inputs = np.transpose(inputs, [1, 0, 2])
if is_reverse:
inputs = np.flip(inputs, 0)
if sequence_length is None:
mask = None
else:
mask = np.transpose(sequence_mask(sequence_length), [1, 0])
mask = np.expand_dims(mask, -1)
if is_reverse:
mask = np.flip(mask, 0)
time_steps = inputs.shape[0]
state = initial_states
outputs = []
for t in range(time_steps):
x_t = inputs[t]
if mask is not None:
m_t = mask[t]
y, new_state = cell(x_t, state)
y = np.where(m_t, y, 0.)
outputs.append(y)
state = update_state(m_t, new_state, state)
else:
y, new_state = cell(x_t, state)
outputs.append(y)
state = new_state
outputs = np.stack(outputs)
final_state = state
if is_reverse:
outputs = np.flip(outputs, 0)
if not time_major:
outputs = np.transpose(outputs, [1, 0, 2])
return outputs, final_state
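# A minimal self-check of this reference `rnn` against the cells above
# (a sketch; shapes and lengths are illustrative):
#
#   cell = SimpleRNNCell(16, 32)
#   x = np.random.randn(4, 23, 16)            # batch-major inputs
#   h0 = np.zeros((4, 32))
#   lengths = np.array([23, 20, 18, 5])
#   y, h = rnn(cell, x, h0, sequence_length=lengths)
#   assert y.shape == (4, 23, 32) and h.shape == (4, 32)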
def birnn(cell_fw,
cell_bw,
inputs,
initial_states,
sequence_length=None,
time_major=False):
states_fw, states_bw = initial_states
outputs_fw, states_fw = rnn(cell_fw,
inputs,
states_fw,
sequence_length,
time_major=time_major)
outputs_bw, states_bw = rnn(cell_bw,
inputs,
states_bw,
sequence_length,
time_major=time_major,
is_reverse=True)
outputs = np.concatenate((outputs_fw, outputs_bw), -1)
final_states = (states_fw, states_bw)
return outputs, final_states
def flatten(nested):
return list(_flatten(nested))
def _flatten(nested):
for item in nested:
if isinstance(item, (list, tuple)):
for subitem in _flatten(item):
yield subitem
else:
yield item
def unstack(array, axis=0):
num = array.shape[axis]
sub_arrays = np.split(array, num, axis)
return [np.squeeze(sub_array, axis) for sub_array in sub_arrays]
def dropout(array, p=0.5):
if p == 0.0:
return array
mask = (np.random.uniform(size=array.shape) < (1 - p)).astype(array.dtype)
return array * (mask / (1 - p))
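# Inverted dropout: surviving activations are rescaled by 1 / (1 - p) so the
# expected value of the output matches the input; p == 0 is a no-op.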
def split_states(states, bidirectional=False, state_components=1):
if state_components == 1:
states = unstack(states)
if not bidirectional:
return states
else:
return list(zip(states[::2], states[1::2]))
else:
assert len(states) == state_components
states = tuple([unstack(item) for item in states])
if not bidirectional:
return list(zip(*states))
else:
states = list(zip(*states))
return list(zip(states[::2], states[1::2]))
def concat_states(states, bidirectional=False, state_components=1):
if state_components == 1:
return np.stack(flatten(states))
else:
states = flatten(states)
components = []
for i in range(state_components):
components.append(states[i::state_components])
return [np.stack(item) for item in components]
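# Example layout (illustrative): for a 2-layer bidirectional LSTM,
# `initial_states` is a pair (h, c) of arrays with shape
# [num_layers * num_directions = 4, batch_size, hidden_size], and
#   split_states((h, c), bidirectional=True, state_components=2)
# yields one ((h, c)_fw, (h, c)_bw) pair per layer:
#   [((h[0], c[0]), (h[1], c[1])), ((h[2], c[2]), (h[3], c[3]))]
# `concat_states` inverts this layout.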
class RNN(LayerMixin):
def __init__(self, cell, is_reverse=False, time_major=False):
super(RNN, self).__init__()
self.cell = cell
if not hasattr(self.cell, "call"):
# for non-dygraph mode, `rnn` api uses cell.call
self.cell.call = self.cell.forward
self.is_reverse = is_reverse
self.time_major = time_major
def forward(self, inputs, initial_states=None, sequence_length=None):
final_outputs, final_states = rnn(self.cell,
inputs,
initial_states=initial_states,
sequence_length=sequence_length,
time_major=self.time_major,
is_reverse=self.is_reverse)
return final_outputs, final_states
class BiRNN(LayerMixin):
def __init__(self, cell_fw, cell_bw, time_major=False):
super(BiRNN, self).__init__()
self.cell_fw = cell_fw
self.cell_bw = cell_bw
self.time_major = time_major
def forward(self,
inputs,
initial_states=None,
sequence_length=None,
**kwargs):
if isinstance(initial_states, (list, tuple)):
assert len(initial_states) == 2, \
"length of initial_states should be 2 when it is a list/tuple"
else:
initial_states = [initial_states, initial_states]
outputs, final_states = birnn(self.cell_fw, self.cell_bw, inputs,
initial_states, sequence_length,
self.time_major)
return outputs, final_states
class RNNMixin(LayerListMixin):
def forward(self, inputs, initial_states=None, sequence_length=None):
batch_index = 1 if self.time_major else 0
batch_size = inputs.shape[batch_index]
dtype = inputs.dtype
if initial_states is None:
state_shape = (self.num_layers * self.num_directions, batch_size,
self.hidden_size)
if self.state_components == 1:
initial_states = np.zeros(state_shape, dtype)
else:
initial_states = tuple([
np.zeros(state_shape, dtype)
for _ in range(self.state_components)
])
states = split_states(initial_states, self.num_directions == 2,
self.state_components)
final_states = []
for i, rnn_layer in enumerate(self):
if i > 0:
inputs = dropout(inputs, self.dropout)
outputs, final_state = rnn_layer(inputs, states[i], sequence_length)
final_states.append(final_state)
inputs = outputs
final_states = concat_states(final_states, self.num_directions == 2,
self.state_components)
return outputs, final_states
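# Note: layers are stacked sequentially; the output sequence of layer i
# feeds layer i + 1, dropout is applied between layers (not after the last
# one), and the per-layer final states are re-packed with `concat_states`.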
class SimpleRNN(RNNMixin):
def __init__(self,
input_size,
hidden_size,
num_layers=1,
nonlinearity="tanh",
direction="forward",
dropout=0.,
time_major=False):
super(SimpleRNN, self).__init__()
if direction in ["forward", "backward"]:
is_reverse = direction == "backward"
cell = SimpleRNNCell(input_size, hidden_size, nonlinearity=nonlinearity)
self.append(RNN(cell, is_reverse, time_major))
for i in range(1, num_layers):
cell = SimpleRNNCell(hidden_size, hidden_size, nonlinearity=nonlinearity)
self.append(RNN(cell, is_reverse, time_major))
elif direction == "bidirectional":
cell_fw = SimpleRNNCell(input_size, hidden_size, nonlinearity=nonlinearity)
cell_bw = SimpleRNNCell(input_size, hidden_size, nonlinearity=nonlinearity)
self.append(BiRNN(cell_fw, cell_bw, time_major))
for i in range(1, num_layers):
cell_fw = SimpleRNNCell(2 * hidden_size, hidden_size,
nonlinearity=nonlinearity)
cell_bw = SimpleRNNCell(2 * hidden_size, hidden_size,
nonlinearity=nonlinearity)
self.append(BiRNN(cell_fw, cell_bw, time_major))
else:
raise ValueError(
"direction should be forward, backward or bidirectional, "
"received direction = {}".format(direction))
self.input_size = input_size
self.hidden_size = hidden_size
self.dropout = dropout
self.num_directions = 2 if direction == "bidirectional" else 1
self.time_major = time_major
self.num_layers = num_layers
self.state_components = 1
class LSTM(RNNMixin):
def __init__(self,
input_size,
hidden_size,
num_layers=1,
direction="forward",
dropout=0.,
time_major=False):
super(LSTM, self).__init__()
if direction in ["forward", "backward"]:
is_reverse = direction == "backward"
cell = LSTMCell(input_size, hidden_size)
self.append(RNN(cell, is_reverse, time_major))
for i in range(1, num_layers):
cell = LSTMCell(hidden_size, hidden_size)
self.append(RNN(cell, is_reverse, time_major))
elif direction == "bidirectional":
cell_fw = LSTMCell(input_size, hidden_size)
cell_bw = LSTMCell(input_size, hidden_size)
self.append(BiRNN(cell_fw, cell_bw, time_major))
for i in range(1, num_layers):
cell_fw = LSTMCell(2 * hidden_size, hidden_size)
cell_bw = LSTMCell(2 * hidden_size, hidden_size)
self.append(BiRNN(cell_fw, cell_bw, time_major))
else:
raise ValueError(
"direction should be forward, backward or bidirectional, "
"received direction = {}".format(direction))
self.input_size = input_size
self.hidden_size = hidden_size
self.dropout = dropout
self.num_directions = 2 if direction == "bidirectional" else 1
self.time_major = time_major
self.num_layers = num_layers
self.state_components = 2
class GRU(RNNMixin):
def __init__(self,
input_size,
hidden_size,
num_layers=1,
direction="forward",
dropout=0.,
time_major=False):
super(GRU, self).__init__()
if direction in ["forward", "backward"]:
is_reverse = direction == "backward"
cell = GRUCell(input_size, hidden_size)
self.append(RNN(cell, is_reverse, time_major))
for i in range(1, num_layers):
cell = GRUCell(hidden_size, hidden_size)
self.append(RNN(cell, is_reverse, time_major))
elif direction == "bidirectional":
cell_fw = GRUCell(input_size, hidden_size)
cell_bw = GRUCell(input_size, hidden_size)
self.append(BiRNN(cell_fw, cell_bw, time_major))
for i in range(1, num_layers):
cell_fw = GRUCell(2 * hidden_size, hidden_size)
cell_bw = GRUCell(2 * hidden_size, hidden_size)
self.append(BiRNN(cell_fw, cell_bw, time_major))
else:
raise ValueError(
"direction should be forward, backward or bidirectional, "
"received direction = {}".format(direction))
self.input_size = input_size
self.hidden_size = hidden_size
self.dropout = dropout
self.num_directions = 2 if direction == "bidirectional" else 1
self.time_major = time_major
self.num_layers = num_layers
self.state_components = 1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
paddle.framework.set_default_dtype("float64")
import numpy as np
import unittest
from rnn_numpy import SimpleRNNCell, LSTMCell, GRUCell
from convert import convert_params_for_cell
class TestSimpleRNNCell(unittest.TestCase):
def __init__(self, bias=True, place="cpu"):
super(TestSimpleRNNCell, self).__init__(methodName="runTest")
self.bias = bias
self.place = paddle.CPUPlace() if place == "cpu" \
else paddle.CUDAPlace(0)
def setUp(self):
paddle.disable_static(self.place)
rnn1 = SimpleRNNCell(16, 32, bias=self.bias)
rnn2 = paddle.nn.SimpleRNNCell(
16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias)
convert_params_for_cell(rnn1, rnn2)
self.rnn1 = rnn1
self.rnn2 = rnn2
def test_with_initial_state(self):
rnn1 = self.rnn1
rnn2 = self.rnn2
x = np.random.randn(4, 16)
prev_h = np.random.randn(4, 32)
y1, h1 = rnn1(x, prev_h)
y2, h2 = rnn2(paddle.to_variable(x), paddle.to_variable(prev_h))
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
def test_with_zero_state(self):
rnn1 = self.rnn1
rnn2 = self.rnn2
x = np.random.randn(4, 16)
y1, h1 = rnn1(x)
y2, h2 = rnn2(paddle.to_variable(x))
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
def runTest(self):
self.test_with_initial_state()
self.test_with_zero_state()
class TestGRUCell(unittest.TestCase):
def __init__(self, bias=True, place="cpu"):
super(TestGRUCell, self).__init__(methodName="runTest")
self.bias = bias
self.place = paddle.CPUPlace() if place == "cpu" \
else paddle.CUDAPlace(0)
def setUp(self):
paddle.disable_static(self.place)
rnn1 = GRUCell(16, 32, bias=self.bias)
rnn2 = paddle.nn.GRUCell(
16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias)
convert_params_for_cell(rnn1, rnn2)
self.rnn1 = rnn1
self.rnn2 = rnn2
def test_with_initial_state(self):
rnn1 = self.rnn1
rnn2 = self.rnn2
x = np.random.randn(4, 16)
prev_h = np.random.randn(4, 32)
y1, h1 = rnn1(x, prev_h)
y2, h2 = rnn2(paddle.to_variable(x), paddle.to_variable(prev_h))
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
def test_with_zero_state(self):
rnn1 = self.rnn1
rnn2 = self.rnn2
x = np.random.randn(4, 16)
y1, h1 = rnn1(x)
y2, h2 = rnn2(paddle.to_variable(x))
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
def runTest(self):
self.test_with_initial_state()
self.test_with_zero_state()
class TestLSTMCell(unittest.TestCase):
def __init__(self, bias=True, place="cpu"):
super(TestLSTMCell, self).__init__(methodName="runTest")
self.bias = bias
self.place = paddle.CPUPlace() if place == "cpu" \
else paddle.CUDAPlace(0)
def setUp(self):
paddle.disable_static(self.place)
rnn1 = LSTMCell(16, 32, bias=self.bias)
rnn2 = paddle.nn.LSTMCell(
16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias)
convert_params_for_cell(rnn1, rnn2)
self.rnn1 = rnn1
self.rnn2 = rnn2
def test_with_initial_state(self):
rnn1 = self.rnn1
rnn2 = self.rnn2
x = np.random.randn(4, 16)
prev_h = np.random.randn(4, 32)
prev_c = np.random.randn(4, 32)
y1, (h1, c1) = rnn1(x, (prev_h, prev_c))
y2, (h2, c2) = rnn2(
paddle.to_variable(x),
(paddle.to_variable(prev_h), paddle.to_variable(prev_c)))
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5)
def test_with_zero_state(self):
rnn1 = self.rnn1
rnn2 = self.rnn2
x = np.random.randn(4, 16)
y1, (h1, c1) = rnn1(x)
y2, (h2, c2) = rnn2(paddle.to_variable(x))
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5)
def runTest(self):
self.test_with_initial_state()
self.test_with_zero_state()
def load_tests(loader, tests, pattern):
suite = unittest.TestSuite()
devices = ["cpu", "gpu"] if paddle.fluid.is_compiled_with_cuda() \
else ["cpu"]
for bias in [True, False]:
for device in devices:
for test_class in [TestSimpleRNNCell, TestGRUCell, TestLSTMCell]:
suite.addTest(test_class(bias, device))
return suite
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
paddle.framework.set_default_dtype("float64")
import numpy as np
import unittest
from convert import convert_params_for_cell_static
from rnn_numpy import SimpleRNNCell, LSTMCell, GRUCell
class TestSimpleRNNCell(unittest.TestCase):
def __init__(self, bias=True, place="cpu"):
super(TestSimpleRNNCell, self).__init__(methodName="runTest")
self.bias = bias
self.place = paddle.CPUPlace() if place == "cpu" \
else paddle.CUDAPlace(0)
def setUp(self):
rnn1 = SimpleRNNCell(16, 32, bias=self.bias)
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
rnn2 = paddle.nn.SimpleRNNCell(
16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias)
place = self.place
exe = paddle.static.Executor(place)
scope = paddle.fluid.Scope()
with paddle.static.scope_guard(scope):
exe.run(sp)
convert_params_for_cell_static(rnn1, rnn2, place)
self.mp = mp
self.sp = sp
self.rnn1 = rnn1
self.rnn2 = rnn2
self.executor = exe
self.scope = scope
def test_with_initial_state(self):
mp = self.mp.clone()
sp = self.sp
rnn1 = self.rnn1
rnn2 = self.rnn2
exe = self.executor
scope = self.scope
x = np.random.randn(4, 16)
prev_h = np.random.randn(4, 32)
y1, h1 = rnn1(x, prev_h)
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
x_data = paddle.data(
"input", [-1, 16],
dtype=paddle.framework.get_default_dtype())
init_h = paddle.data(
"init_h", [-1, 32],
dtype=paddle.framework.get_default_dtype())
y, h = rnn2(x_data, init_h)
feed_dict = {x_data.name: x, init_h.name: prev_h}
with paddle.static.scope_guard(scope):
y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h])
np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5)
def test_with_zero_state(self):
mp = self.mp.clone()
sp = self.sp
rnn1 = self.rnn1
rnn2 = self.rnn2
exe = self.executor
scope = self.scope
x = np.random.randn(4, 16)
y1, h1 = rnn1(x)
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
x_data = paddle.data(
"input", [-1, 16],
dtype=paddle.framework.get_default_dtype())
y, h = rnn2(x_data)
feed_dict = {x_data.name: x}
with paddle.static.scope_guard(scope):
y2, h2 = exe.run(mp,
feed=feed_dict,
fetch_list=[y, h],
use_prune=True)
np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5)
def runTest(self):
self.test_with_initial_state()
self.test_with_zero_state()
class TestGRUCell(unittest.TestCase):
def __init__(self, bias=True, place="cpu"):
super(TestGRUCell, self).__init__(methodName="runTest")
self.bias = bias
self.place = paddle.CPUPlace() if place == "cpu" \
else paddle.CUDAPlace(0)
def setUp(self):
rnn1 = GRUCell(16, 32, bias=self.bias)
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
rnn2 = paddle.nn.GRUCell(
16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias)
place = self.place
exe = paddle.static.Executor(place)
scope = paddle.fluid.Scope()
with paddle.static.scope_guard(scope):
exe.run(sp)
convert_params_for_cell_static(rnn1, rnn2, place)
self.mp = mp
self.sp = sp
self.rnn1 = rnn1
self.rnn2 = rnn2
self.place = place
self.executor = exe
self.scope = scope
def test_with_initial_state(self):
mp = self.mp.clone()
sp = self.sp
rnn1 = self.rnn1
rnn2 = self.rnn2
exe = self.executor
scope = self.scope
x = np.random.randn(4, 16)
prev_h = np.random.randn(4, 32)
y1, h1 = rnn1(x, prev_h)
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
x_data = paddle.data(
"input", [-1, 16],
dtype=paddle.framework.get_default_dtype())
init_h = paddle.data(
"init_h", [-1, 32],
dtype=paddle.framework.get_default_dtype())
y, h = rnn2(x_data, init_h)
feed_dict = {x_data.name: x, init_h.name: prev_h}
with paddle.static.scope_guard(scope):
y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h])
np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5)
def test_with_zero_state(self):
mp = self.mp.clone()
sp = self.sp
rnn1 = self.rnn1
rnn2 = self.rnn2
exe = self.executor
scope = self.scope
x = np.random.randn(4, 16)
y1, h1 = rnn1(x)
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
x_data = paddle.data(
"input", [-1, 16],
dtype=paddle.framework.get_default_dtype())
y, h = rnn2(x_data)
feed_dict = {x_data.name: x}
with paddle.static.scope_guard(scope):
y2, h2 = exe.run(mp,
feed=feed_dict,
fetch_list=[y, h],
use_prune=True)
np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5)
def runTest(self):
self.test_with_initial_state()
self.test_with_zero_state()
class TestLSTMCell(unittest.TestCase):
def __init__(self, bias=True, place="cpu"):
super(TestLSTMCell, self).__init__(methodName="runTest")
self.bias = bias
self.place = paddle.CPUPlace() if place == "cpu" \
else paddle.CUDAPlace(0)
def setUp(self):
rnn1 = LSTMCell(16, 32, bias=self.bias)
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
rnn2 = paddle.nn.LSTMCell(
16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias)
place = self.place
exe = paddle.static.Executor(place)
scope = paddle.fluid.Scope()
with paddle.static.scope_guard(scope):
exe.run(sp)
convert_params_for_cell_static(rnn1, rnn2, place)
self.mp = mp
self.sp = sp
self.rnn1 = rnn1
self.rnn2 = rnn2
self.place = place
self.executor = exe
self.scope = scope
def test_with_initial_state(self):
mp = self.mp.clone()
sp = self.sp
rnn1 = self.rnn1
rnn2 = self.rnn2
exe = self.executor
scope = self.scope
x = np.random.randn(4, 16)
prev_h = np.random.randn(4, 32)
prev_c = np.random.randn(4, 32)
y1, (h1, c1) = rnn1(x, (prev_h, prev_c))
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
x_data = paddle.data(
"input", [-1, 16],
dtype=paddle.framework.get_default_dtype())
init_h = paddle.data(
"init_h", [-1, 32],
dtype=paddle.framework.get_default_dtype())
init_c = paddle.data(
"init_c", [-1, 32],
dtype=paddle.framework.get_default_dtype())
y, (h, c) = rnn2(x_data, (init_h, init_c))
feed_dict = {x_data.name: x, init_h.name: prev_h, init_c.name: prev_c}
with paddle.static.scope_guard(scope):
y2, h2, c2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h, c])
np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(c1, c2, atol=1e-8, rtol=1e-5)
def test_with_zero_state(self):
mp = self.mp.clone()
sp = self.sp
rnn1 = self.rnn1
rnn2 = self.rnn2
exe = self.executor
scope = self.scope
x = np.random.randn(4, 16)
y1, (h1, c1) = rnn1(x)
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
x_data = paddle.data(
"input", [-1, 16],
dtype=paddle.framework.get_default_dtype())
y, (h, c) = rnn2(x_data)
feed_dict = {x_data.name: x}
with paddle.static.scope_guard(scope):
y2, h2, c2 = exe.run(mp,
feed=feed_dict,
fetch_list=[y, h, c],
use_prune=True)
np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(c1, c2, atol=1e-8, rtol=1e-5)
def runTest(self):
self.test_with_initial_state()
self.test_with_zero_state()
def load_tests(loader, tests, pattern):
suite = unittest.TestSuite()
devices = ["cpu", "gpu"] if paddle.fluid.is_compiled_with_cuda() \
else ["cpu"]
for bias in [True, False]:
for device in devices:
for test_class in [TestSimpleRNNCell, TestGRUCell, TestLSTMCell]:
suite.addTest(test_class(bias, device))
return suite
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
paddle.set_default_dtype("float64")
from paddle.fluid.layers import sequence_mask
import numpy as np
import unittest
from convert import convert_params_for_net
from rnn_numpy import SimpleRNN, LSTM, GRU
class TestSimpleRNN(unittest.TestCase):
def __init__(self, time_major=True, direction="forward", place="cpu"):
super(TestSimpleRNN, self).__init__("runTest")
self.time_major = time_major
self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \
else paddle.CUDAPlace(0)
def setUp(self):
paddle.disable_static(self.place)
rnn1 = SimpleRNN(
16, 32, 2, time_major=self.time_major, direction=self.direction)
rnn2 = paddle.nn.SimpleRNN(
16, 32, 2, time_major=self.time_major, direction=self.direction)
convert_params_for_net(rnn1, rnn2)
self.rnn1 = rnn1
self.rnn2 = rnn2
def test_with_initial_state(self):
rnn1 = self.rnn1
rnn2 = self.rnn2
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
prev_h = np.random.randn(2 * self.num_directions, 4, 32)
y1, h1 = rnn1(x, prev_h)
y2, h2 = rnn2(paddle.to_variable(x), paddle.to_variable(prev_h))
np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
def test_with_zero_state(self):
rnn1 = self.rnn1
rnn2 = self.rnn2
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
y1, h1 = rnn1(x)
y2, h2 = rnn2(paddle.to_variable(x))
np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
def test_with_input_lengths(self):
rnn1 = self.rnn1
rnn2 = self.rnn2
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
sequence_length = np.array([12, 10, 9, 8], dtype=np.int64)
y1, h1 = rnn1(x, sequence_length=sequence_length)
seq_len = paddle.to_variable(sequence_length)
mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype())
if self.time_major:
mask = paddle.transpose(mask, [1, 0])
y2, h2 = rnn2(paddle.to_variable(x), sequence_length=seq_len)
y2 = paddle.multiply(y2, mask, axis=0)
np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
def runTest(self):
self.test_with_initial_state()
self.test_with_zero_state()
self.test_with_input_lengths()
class TestGRU(unittest.TestCase):
def __init__(self, time_major=True, direction="forward", place="cpu"):
super(TestGRU, self).__init__("runTest")
self.time_major = time_major
self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \
else paddle.CUDAPlace(0)
def setUp(self):
paddle.disable_static(self.place)
rnn1 = GRU(16,
32,
2,
time_major=self.time_major,
direction=self.direction)
rnn2 = paddle.nn.GRU(16,
32,
2,
time_major=self.time_major,
direction=self.direction)
convert_params_for_net(rnn1, rnn2)
self.rnn1 = rnn1
self.rnn2 = rnn2
def test_with_initial_state(self):
rnn1 = self.rnn1
rnn2 = self.rnn2
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
prev_h = np.random.randn(2 * self.num_directions, 4, 32)
y1, h1 = rnn1(x, prev_h)
y2, h2 = rnn2(paddle.to_variable(x), paddle.to_variable(prev_h))
np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
def test_with_zero_state(self):
rnn1 = self.rnn1
rnn2 = self.rnn2
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
y1, h1 = rnn1(x)
y2, h2 = rnn2(paddle.to_variable(x))
np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
def test_with_input_lengths(self):
rnn1 = self.rnn1
rnn2 = self.rnn2
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
sequence_length = np.array([12, 10, 9, 8], dtype=np.int64)
y1, h1 = rnn1(x, sequence_length=sequence_length)
seq_len = paddle.to_variable(sequence_length)
mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype())
if self.time_major:
mask = paddle.transpose(mask, [1, 0])
y2, h2 = rnn2(paddle.to_variable(x), sequence_length=seq_len)
y2 = paddle.multiply(y2, mask, axis=0)
np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
def runTest(self):
self.test_with_initial_state()
self.test_with_zero_state()
self.test_with_input_lengths()
class TestLSTM(unittest.TestCase):
def __init__(self, time_major=True, direction="forward", place="cpu"):
super(TestLSTM, self).__init__("runTest")
self.time_major = time_major
self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \
else paddle.CUDAPlace(0)
def setUp(self):
paddle.disable_static(self.place)
rnn1 = LSTM(
16, 32, 2, time_major=self.time_major, direction=self.direction)
rnn2 = paddle.nn.LSTM(
16, 32, 2, time_major=self.time_major, direction=self.direction)
convert_params_for_net(rnn1, rnn2)
self.rnn1 = rnn1
self.rnn2 = rnn2
def test_with_initial_state(self):
rnn1 = self.rnn1
rnn2 = self.rnn2
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
prev_h = np.random.randn(2 * self.num_directions, 4, 32)
prev_c = np.random.randn(2 * self.num_directions, 4, 32)
y1, (h1, c1) = rnn1(x, (prev_h, prev_c))
y2, (h2, c2) = rnn2(
paddle.to_variable(x),
(paddle.to_variable(prev_h), paddle.to_variable(prev_c)))
np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5)
def test_with_zero_state(self):
rnn1 = self.rnn1
rnn2 = self.rnn2
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
y1, (h1, c1) = rnn1(x)
y2, (h2, c2) = rnn2(paddle.to_variable(x))
np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5)
def test_with_input_lengths(self):
rnn1 = self.rnn1
rnn2 = self.rnn2
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
sequence_length = np.array([12, 10, 9, 8], dtype=np.int64)
y1, (h1, c1) = rnn1(x, sequence_length=sequence_length)
seq_len = paddle.to_variable(sequence_length)
mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype())
if self.time_major:
mask = paddle.transpose(mask, [1, 0])
y2, (h2, c2) = rnn2(paddle.to_variable(x), sequence_length=seq_len)
y2 = paddle.multiply(y2, mask, axis=0)
np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5)
def runTest(self):
self.test_with_initial_state()
self.test_with_zero_state()
self.test_with_input_lengths()
def load_tests(loader, tests, pattern):
suite = unittest.TestSuite()
devices = ["cpu", "gpu"] if paddle.fluid.is_compiled_with_cuda() \
else ["cpu"]
for direction in ["forward", "backward", "bidirectional"]:
for time_major in [True, False]:
for device in devices:
for test_class in [TestSimpleRNN, TestLSTM, TestGRU]:
suite.addTest(test_class(time_major, direction, device))
return suite
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
paddle.set_default_dtype("float64")
from paddle.fluid.layers import sequence_mask
import numpy as np
import unittest
from convert import convert_params_for_net_static
from rnn_numpy import SimpleRNN, LSTM, GRU
class TestSimpleRNN(unittest.TestCase):
def __init__(self, time_major=True, direction="forward", place="cpu"):
super(TestSimpleRNN, self).__init__("runTest")
self.time_major = time_major
self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \
else paddle.CUDAPlace(0)
def setUp(self):
rnn1 = SimpleRNN(
16, 32, 2, time_major=self.time_major, direction=self.direction)
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
rnn2 = paddle.nn.SimpleRNN(
16,
32,
2,
time_major=self.time_major,
direction=self.direction)
place = self.place
exe = paddle.static.Executor(place)
scope = paddle.fluid.Scope()
with paddle.static.scope_guard(scope):
exe.run(sp)
convert_params_for_net_static(rnn1, rnn2, place)
self.mp = mp
self.sp = sp
self.rnn1 = rnn1
self.rnn2 = rnn2
self.place = place
self.executor = exe
self.scope = scope
def test_with_initial_state(self):
mp = self.mp.clone()
sp = self.sp
rnn1 = self.rnn1
rnn2 = self.rnn2
exe = self.executor
scope = self.scope
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
prev_h = np.random.randn(2 * self.num_directions, 4, 32)
y1, h1 = rnn1(x, prev_h)
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
x_data = paddle.data(
"input", [-1, -1, 16],
dtype=paddle.framework.get_default_dtype())
init_h = paddle.data(
"init_h", [2 * self.num_directions, -1, 32],
dtype=paddle.framework.get_default_dtype())
y, h = rnn2(x_data, init_h)
feed_dict = {x_data.name: x, init_h.name: prev_h}
with paddle.static.scope_guard(scope):
y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h])
np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5)
def test_with_zero_state(self):
mp = self.mp.clone()
sp = self.sp
rnn1 = self.rnn1
rnn2 = self.rnn2
exe = self.executor
scope = self.scope
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
y1, h1 = rnn1(x)
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
x_data = paddle.data(
"input", [-1, -1, 16],
dtype=paddle.framework.get_default_dtype())
y, h = rnn2(x_data)
feed_dict = {x_data.name: x}
with paddle.static.scope_guard(scope):
y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h])
np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5)
def test_with_input_lengths(self):
mp = self.mp.clone()
sp = self.sp
rnn1 = self.rnn1
rnn2 = self.rnn2
exe = self.executor
scope = self.scope
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
sequence_length = np.array([12, 10, 9, 8], dtype=np.int64)
y1, h1 = rnn1(x, sequence_length=sequence_length)
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
x_data = paddle.data(
"input", [-1, -1, 16],
dtype=paddle.framework.get_default_dtype())
seq_len = paddle.data("seq_len", [-1], dtype="int64")
mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype())
if self.time_major:
mask = paddle.transpose(mask, [1, 0])
y, h = rnn2(x_data, sequence_length=seq_len)
y = paddle.multiply(y, mask, axis=0)
feed_dict = {x_data.name: x, seq_len.name: sequence_length}
with paddle.static.scope_guard(scope):
y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h])
np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5)
def runTest(self):
self.test_with_initial_state()
self.test_with_zero_state()
self.test_with_input_lengths()
class TestGRU(unittest.TestCase):
def __init__(self, time_major=True, direction="forward", place="cpu"):
super(TestGRU, self).__init__("runTest")
self.time_major = time_major
self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \
else paddle.CUDAPlace(0)
def setUp(self):
rnn1 = GRU(16,
32,
2,
time_major=self.time_major,
direction=self.direction)
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
rnn2 = paddle.nn.GRU(16,
32,
2,
time_major=self.time_major,
direction=self.direction)
place = self.place
exe = paddle.static.Executor(place)
scope = paddle.fluid.Scope()
with paddle.static.scope_guard(scope):
exe.run(sp)
convert_params_for_net_static(rnn1, rnn2, place)
self.mp = mp
self.sp = sp
self.rnn1 = rnn1
self.rnn2 = rnn2
self.place = place
self.executor = exe
self.scope = scope
def test_with_initial_state(self):
mp = self.mp.clone()
sp = self.sp
rnn1 = self.rnn1
rnn2 = self.rnn2
exe = self.executor
scope = self.scope
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
prev_h = np.random.randn(2 * self.num_directions, 4, 32)
y1, h1 = rnn1(x, prev_h)
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
x_data = paddle.data(
"input", [-1, -1, 16],
dtype=paddle.framework.get_default_dtype())
init_h = paddle.data(
"init_h", [2 * self.num_directions, -1, 32],
dtype=paddle.framework.get_default_dtype())
y, h = rnn2(x_data, init_h)
feed_dict = {x_data.name: x, init_h.name: prev_h}
with paddle.static.scope_guard(scope):
y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h])
np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5)
def test_with_zero_state(self):
mp = self.mp.clone()
sp = self.sp
rnn1 = self.rnn1
rnn2 = self.rnn2
exe = self.executor
scope = self.scope
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
y1, h1 = rnn1(x)
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
x_data = paddle.data(
"input", [-1, -1, 16],
dtype=paddle.framework.get_default_dtype())
y, h = rnn2(x_data)
feed_dict = {x_data.name: x}
with paddle.static.scope_guard(scope):
y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h])
np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5)
def test_with_input_lengths(self):
mp = self.mp.clone()
sp = self.sp
rnn1 = self.rnn1
rnn2 = self.rnn2
exe = self.executor
scope = self.scope
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
sequence_length = np.array([12, 10, 9, 8], dtype=np.int64)
y1, h1 = rnn1(x, sequence_length=sequence_length)
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
x_data = paddle.data(
"input", [-1, -1, 16],
dtype=paddle.framework.get_default_dtype())
seq_len = paddle.data("seq_len", [-1], dtype="int64")
mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype())
if self.time_major:
mask = paddle.transpose(mask, [1, 0])
y, h = rnn2(x_data, sequence_length=seq_len)
y = paddle.multiply(y, mask, axis=0)
feed_dict = {x_data.name: x, seq_len.name: sequence_length}
with paddle.static.scope_guard(scope):
y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h])
np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5)
    def runTest(self):
        self.test_with_initial_state()
        self.test_with_zero_state()
        self.test_with_input_lengths()
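# TestLSTM repeats the same cross-check pattern, but with (h, c) state tuples
# in place of a single hidden state.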
class TestLSTM(unittest.TestCase):
def __init__(self, time_major=True, direction="forward", place="cpu"):
super(TestLSTM, self).__init__("runTest")
self.time_major = time_major
self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \
else paddle.CUDAPlace(0)
def setUp(self):
rnn1 = LSTM(
16, 32, 2, time_major=self.time_major, direction=self.direction)
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
rnn2 = paddle.nn.LSTM(
16,
32,
2,
time_major=self.time_major,
direction=self.direction)
place = self.place
exe = paddle.static.Executor(place)
scope = paddle.fluid.Scope()
with paddle.static.scope_guard(scope):
exe.run(sp)
convert_params_for_net_static(rnn1, rnn2, place)
self.mp = mp
self.sp = sp
self.rnn1 = rnn1
self.rnn2 = rnn2
self.place = place
self.executor = exe
self.scope = scope
def test_with_initial_state(self):
mp = self.mp.clone()
sp = self.sp
rnn1 = self.rnn1
rnn2 = self.rnn2
exe = self.executor
scope = self.scope
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
prev_h = np.random.randn(2 * self.num_directions, 4, 32)
prev_c = np.random.randn(2 * self.num_directions, 4, 32)
y1, (h1, c1) = rnn1(x, (prev_h, prev_c))
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
x_data = paddle.data(
"input", [-1, -1, 16],
dtype=paddle.framework.get_default_dtype())
init_h = paddle.data(
"init_h", [2 * self.num_directions, -1, 32],
dtype=paddle.framework.get_default_dtype())
init_c = paddle.data(
"init_c", [2 * self.num_directions, -1, 32],
dtype=paddle.framework.get_default_dtype())
y, (h, c) = rnn2(x_data, (init_h, init_c))
feed_dict = {x_data.name: x, init_h.name: prev_h, init_c.name: prev_c}
with paddle.static.scope_guard(scope):
y2, h2, c2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h, c])
np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(c1, c2, atol=1e-8, rtol=1e-5)
def test_with_zero_state(self):
mp = self.mp.clone()
sp = self.sp
rnn1 = self.rnn1
rnn2 = self.rnn2
exe = self.executor
scope = self.scope
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
y1, (h1, c1) = rnn1(x)
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
x_data = paddle.data(
"input", [-1, -1, 16],
dtype=paddle.framework.get_default_dtype())
y, (h, c) = rnn2(x_data)
feed_dict = {x_data.name: x}
with paddle.static.scope_guard(scope):
y2, h2, c2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h, c])
np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(c1, c2, atol=1e-8, rtol=1e-5)
def test_with_input_lengths(self):
mp = self.mp.clone()
sp = self.sp
rnn1 = self.rnn1
rnn2 = self.rnn2
exe = self.executor
scope = self.scope
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
sequence_length = np.array([12, 10, 9, 8], dtype=np.int64)
y1, (h1, c1) = rnn1(x, sequence_length=sequence_length)
with paddle.fluid.unique_name.guard():
with paddle.static.program_guard(mp, sp):
x_data = paddle.data(
"input", [-1, -1, 16],
dtype=paddle.framework.get_default_dtype())
seq_len = paddle.data("seq_len", [-1], dtype="int64")
mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype())
if self.time_major:
mask = paddle.transpose(mask, [1, 0])
y, (h, c) = rnn2(x_data, sequence_length=seq_len)
y = paddle.multiply(y, mask, axis=0)
feed_dict = {x_data.name: x, seq_len.name: sequence_length}
with paddle.static.scope_guard(scope):
y2, h2, c2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h, c])
np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(c1, c2, atol=1e-8, rtol=1e-5)
def runTest(self):
self.test_with_initial_state()
self.test_with_zero_state()
self.test_with_input_lengths()
def load_tests(loader, tests, pattern):
suite = unittest.TestSuite()
devices = ["cpu", "gpu"] if paddle.fluid.is_compiled_with_cuda() \
else ["cpu"]
for direction in ["forward", "backward", "bidirectional"]:
for time_major in [True, False]:
for device in devices:
for test_class in [TestSimpleRNN, TestLSTM, TestGRU]:
suite.addTest(test_class(time_major, direction, device))
return suite
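# A standard unittest entry point (a minimal sketch; the original file may
# wire the runner up differently):
if __name__ == '__main__':
    unittest.main()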
......@@ -18,6 +18,7 @@
from .layer import norm
from .functional import extension
from .layer import common
from .layer import rnn
from .utils import weight_norm_hook
from . import initializer
......@@ -26,6 +27,7 @@ __all__ = []
__all__ += norm.__all__
__all__ += extension.__all__
__all__ += common.__all__
__all__ += rnn.__all__
__all__ += weight_norm_hook.__all__
# TODO: define alias in nn directory
......@@ -136,6 +138,7 @@ from .layer.norm import InstanceNorm3d #DEFINE_ALIAS
from .layer.norm import BatchNorm1d #DEFINE_ALIAS
from .layer.norm import BatchNorm2d #DEFINE_ALIAS
from .layer.norm import BatchNorm3d #DEFINE_ALIAS
from .layer.rnn import *
# from .layer.rnn import RNNCell #DEFINE_ALIAS
# from .layer.rnn import GRUCell #DEFINE_ALIAS
# from .layer.rnn import LSTMCell #DEFINE_ALIAS
......
......@@ -177,6 +177,8 @@ from .pooling import pool2d #DEFINE_ALIAS
from .pooling import pool3d #DEFINE_ALIAS
from .pooling import adaptive_pool2d #DEFINE_ALIAS
from .pooling import adaptive_pool3d #DEFINE_ALIAS
from .rnn import rnn #DEFINE_ALIAS
from .rnn import birnn #DEFINE_ALIAS
from .pooling import avg_pool2d #DEFINE_ALIAS
from .pooling import max_pool2d #DEFINE_ALIAS
from .pooling import avg_pool3d #DEFINE_ALIAS
......
......@@ -12,10 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: define function of recurrent neural network
from paddle.fluid.layers.rnn import rnn, birnn
__all__ = [
# 'gru_unit',
# 'lstm',
# 'lstm_unit'
]
__all__ = ['rnn', 'birnn']
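# A minimal usage sketch for the re-exported functional APIs (the cell sizes,
# input shapes, and dygraph mode below are illustrative assumptions, not part
# of this diff):
#
#   import paddle
#   cell = paddle.nn.LSTMCell(16, 32)
#   inputs = paddle.rand((4, 23, 16))  # [batch_size, time_steps, input_size]
#   outputs, final_states = paddle.nn.functional.rnn(cell, inputs)
#
#   fw_cell, bw_cell = paddle.nn.GRUCell(16, 32), paddle.nn.GRUCell(16, 32)
#   outputs, final_states = paddle.nn.functional.birnn(fw_cell, bw_cell, inputs)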
......@@ -20,6 +20,7 @@ from . import conv
from . import extension
from . import activation
from . import norm
from . import rnn
from . import vision
from . import distance
from . import transformer
......@@ -30,6 +31,7 @@ from .conv import *
from .extension import *
from .activation import *
from .norm import *
from .rnn import *
from .vision import *
from .transformer import *
......
This diff is collapsed.
......@@ -148,7 +148,20 @@
"Callback.on_eval_batch_end",
"Callback.on_test_batch_begin",
"Callback.on_test_batch_end",
"Model.prepare"
"Model.prepare",
"SimpleRNNCell",
"SimpleRNNCell.forward",
"LSTMCell",
"LSTMCell.forward",
"GRUCell",
"GRUCell.forward",
"SimpleRNN",
"GRU",
"LSTM",
"RNN",
"BiRNN",
"RNNCellBase",
"RNNCellBase.get_initial_states"
],
"wlist_no_op_pass":[
"gelu",
......