未验证 提交 acee3dd3 编写于 作者: L lugimzzz 提交者: GitHub

[fluid clean] remove 4 fluid.layers api and imigrate 2 fluid.layer api (#48972)

* fluid clean layer

* docs
上级 b06a5946
此差异已折叠。
......@@ -2179,26 +2179,6 @@ class TestBook(LayerTest):
x, kernel_size=[5, 3], stride=[1, 2], padding=(2, 1)
)
def make_lstm_unit(self):
with program_guard(
fluid.default_main_program(), fluid.default_startup_program()
):
x_t_data = self._get_data(
name='x_t_data', shape=[10, 10], dtype='float32'
)
x_t = layers.fc(input=x_t_data, size=10)
prev_hidden_data = self._get_data(
name='prev_hidden_data', shape=[10, 30], dtype='float32'
)
prev_hidden = layers.fc(input=prev_hidden_data, size=30)
prev_cell_data = self._get_data(
name='prev_cell', shape=[10, 30], dtype='float32'
)
prev_cell = layers.fc(input=prev_cell_data, size=30)
return layers.lstm_unit(
x_t=x_t, hidden_t_prev=prev_hidden, cell_t_prev=prev_cell
)
def make_softmax(self):
with program_guard(
fluid.default_main_program(), fluid.default_startup_program()
......
......@@ -17,10 +17,6 @@ import unittest
import numpy as np
from op_test import OpTest
from paddle import fluid
from paddle.fluid.framework import Program, program_guard
from paddle.fluid.layers import lstm_unit
def sigmoid_np(x):
return 1.0 / (1.0 + np.exp(-x))
......@@ -30,79 +26,6 @@ def tanh_np(x):
return 2 * sigmoid_np(2.0 * x) - 1.0
class LstmUnitTestError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
batch_size, dict_dim, emb_dim, hidden_dim = 32, 128, 64, 512
data = fluid.data(
name='step_data', shape=[batch_size], dtype='int64'
)
inputs = fluid.embedding(input=data, size=[dict_dim, emb_dim])
pre_hidden = fluid.data(
name='pre_hidden',
shape=[batch_size, hidden_dim],
dtype='float32',
)
pre_cell = fluid.data(
name='pre_cell', shape=[batch_size, hidden_dim], dtype='float32'
)
np_input = np.random.uniform(
-0.1, 0.1, (batch_size, emb_dim)
).astype('float64')
np_pre_hidden = np.random.uniform(
-0.1, 0.1, (batch_size, hidden_dim)
).astype('float64')
np_pre_cell = np.random.uniform(
-0.1, 0.1, (batch_size, hidden_dim)
).astype('float64')
def test_input_Variable():
lstm_unit(np_input, pre_hidden, pre_cell)
self.assertRaises(TypeError, test_input_Variable)
def test_pre_hidden_Variable():
lstm_unit(inputs, np_pre_hidden, pre_cell)
self.assertRaises(TypeError, test_pre_hidden_Variable)
def test_pre_cell_Variable():
lstm_unit(inputs, pre_hidden, np_pre_cell)
self.assertRaises(TypeError, test_pre_cell_Variable)
def test_input_type():
error_input = fluid.data(
name='error_input',
shape=[batch_size, emb_dim],
dtype='int32',
)
lstm_unit(error_input, pre_hidden, pre_cell)
self.assertRaises(TypeError, test_input_type)
def test_pre_hidden_type():
error_pre_hidden = fluid.data(
name='error_pre_hidden',
shape=[batch_size, hidden_dim],
dtype='int32',
)
lstm_unit(inputs, error_pre_hidden, pre_cell)
self.assertRaises(TypeError, test_pre_hidden_type)
def test_pre_cell_type():
error_pre_cell = fluid.data(
name='error_pre_cell',
shape=[batch_size, hidden_dim],
dtype='int32',
)
lstm_unit(inputs, pre_hidden, error_pre_cell)
self.assertRaises(TypeError, test_pre_cell_type)
class LstmUnitTest(OpTest):
def setUp(self):
self.op_type = "lstm_unit"
......
......@@ -19,12 +19,10 @@ import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
import paddle.nn as nn
from paddle import Model, set_device
from paddle.fluid.dygraph import Layer
from paddle.fluid.executor import Executor
from paddle.fluid.framework import _test_eager_guard
from paddle.nn import BeamSearchDecoder, dynamic_decode
from paddle.static import InputSpec as Input
......@@ -32,257 +30,6 @@ from paddle.static import InputSpec as Input
paddle.enable_static()
class EncoderCell(layers.RNNCell):
def __init__(self, num_layers, hidden_size, dropout_prob=0.0):
self.num_layers = num_layers
self.hidden_size = hidden_size
self.dropout_prob = dropout_prob
self.lstm_cells = [
layers.LSTMCell(hidden_size) for i in range(num_layers)
]
def call(self, step_input, states):
new_states = []
for i in range(self.num_layers):
out, new_state = self.lstm_cells[i](step_input, states[i])
step_input = (
layers.dropout(out, self.dropout_prob)
if self.dropout_prob > 0
else out
)
new_states.append(new_state)
return step_input, new_states
@property
def state_shape(self):
return [cell.state_shape for cell in self.lstm_cells]
class DecoderCell(layers.RNNCell):
def __init__(self, num_layers, hidden_size, dropout_prob=0.0):
self.num_layers = num_layers
self.hidden_size = hidden_size
self.dropout_prob = dropout_prob
self.lstm_cells = [
layers.LSTMCell(hidden_size) for i in range(num_layers)
]
def attention(self, hidden, encoder_output, encoder_padding_mask):
query = layers.fc(
hidden, size=encoder_output.shape[-1], bias_attr=False
)
attn_scores = paddle.matmul(
layers.unsqueeze(query, [1]), encoder_output, transpose_y=True
)
if encoder_padding_mask is not None:
attn_scores = paddle.add(attn_scores, encoder_padding_mask)
attn_scores = paddle.nn.functional.softmax(attn_scores)
attn_out = paddle.squeeze(
paddle.matmul(attn_scores, encoder_output), [1]
)
attn_out = layers.concat([attn_out, hidden], 1)
attn_out = layers.fc(attn_out, size=self.hidden_size, bias_attr=False)
return attn_out
def call(
self, step_input, states, encoder_output, encoder_padding_mask=None
):
lstm_states, input_feed = states
new_lstm_states = []
step_input = layers.concat([step_input, input_feed], 1)
for i in range(self.num_layers):
out, new_lstm_state = self.lstm_cells[i](step_input, lstm_states[i])
step_input = (
layers.dropout(out, self.dropout_prob)
if self.dropout_prob > 0
else out
)
new_lstm_states.append(new_lstm_state)
out = self.attention(step_input, encoder_output, encoder_padding_mask)
return out, [new_lstm_states, out]
class Encoder:
def __init__(self, num_layers, hidden_size, dropout_prob=0.0):
self.encoder_cell = EncoderCell(num_layers, hidden_size, dropout_prob)
def __call__(self, src_emb, src_sequence_length):
encoder_output, encoder_final_state = layers.rnn(
cell=self.encoder_cell,
inputs=src_emb,
sequence_length=src_sequence_length,
is_reverse=False,
)
return encoder_output, encoder_final_state
class Decoder:
def __init__(
self,
num_layers,
hidden_size,
dropout_prob,
decoding_strategy="infer_sample",
max_decoding_length=20,
):
self.decoder_cell = DecoderCell(num_layers, hidden_size, dropout_prob)
self.decoding_strategy = decoding_strategy
self.max_decoding_length = (
None
if (self.decoding_strategy == "train_greedy")
else max_decoding_length
)
def __call__(
self,
decoder_initial_states,
encoder_output,
encoder_padding_mask,
**kwargs
):
output_layer = kwargs.pop("output_layer", None)
beam_size = kwargs.get("beam_size", 4)
encoder_output = BeamSearchDecoder.tile_beam_merge_with_batch(
encoder_output, beam_size
)
encoder_padding_mask = BeamSearchDecoder.tile_beam_merge_with_batch(
encoder_padding_mask, beam_size
)
decoder = BeamSearchDecoder(
cell=self.decoder_cell, output_fn=output_layer, **kwargs
)
(
decoder_output,
decoder_final_state,
dec_seq_lengths,
) = layers.dynamic_decode(
decoder,
inits=decoder_initial_states,
max_step_num=self.max_decoding_length,
encoder_output=encoder_output,
encoder_padding_mask=encoder_padding_mask,
impute_finished=False # for test coverage
if self.decoding_strategy == "beam_search"
else True,
is_test=True if self.decoding_strategy == "beam_search" else False,
return_length=True,
)
return decoder_output, decoder_final_state, dec_seq_lengths
class Seq2SeqModel:
"""Seq2Seq model: RNN encoder-decoder with attention"""
def __init__(
self,
num_layers,
hidden_size,
dropout_prob,
src_vocab_size,
trg_vocab_size,
start_token,
end_token,
decoding_strategy="infer_sample",
max_decoding_length=20,
beam_size=4,
):
self.start_token, self.end_token = start_token, end_token
self.max_decoding_length, self.beam_size = (
max_decoding_length,
beam_size,
)
self.src_embeder = paddle.nn.Embedding(
src_vocab_size,
hidden_size,
weight_attr=fluid.ParamAttr(name="source_embedding"),
)
self.trg_embeder = paddle.nn.Embedding(
trg_vocab_size,
hidden_size,
weight_attr=fluid.ParamAttr(name="target_embedding"),
)
self.encoder = Encoder(num_layers, hidden_size, dropout_prob)
self.decoder = Decoder(
num_layers,
hidden_size,
dropout_prob,
decoding_strategy,
max_decoding_length,
)
self.output_layer = lambda x: layers.fc(
x,
size=trg_vocab_size,
num_flatten_dims=len(x.shape) - 1,
param_attr=fluid.ParamAttr(),
bias_attr=False,
)
def __call__(self, src, src_length, trg=None, trg_length=None):
# encoder
encoder_output, encoder_final_state = self.encoder(
self.src_embeder(src), src_length
)
decoder_initial_states = [
encoder_final_state,
self.decoder.decoder_cell.get_initial_states(
batch_ref=encoder_output, shape=[encoder_output.shape[-1]]
),
]
src_mask = layers.sequence_mask(
src_length, maxlen=paddle.shape(src)[1], dtype="float32"
)
encoder_padding_mask = (src_mask - 1.0) * 1e9
encoder_padding_mask = layers.unsqueeze(encoder_padding_mask, [1])
# decoder
decoder_kwargs = (
{
"inputs": self.trg_embeder(trg),
"sequence_length": trg_length,
}
if self.decoder.decoding_strategy == "train_greedy"
else (
{
"embedding_fn": self.trg_embeder,
"beam_size": self.beam_size,
"start_token": self.start_token,
"end_token": self.end_token,
}
if self.decoder.decoding_strategy == "beam_search"
else {
"embedding_fn": self.trg_embeder,
"start_tokens": layers.fill_constant_batch_size_like(
input=encoder_output,
shape=[-1],
dtype=src.dtype,
value=self.start_token,
),
"end_token": self.end_token,
}
)
)
decoder_kwargs["output_layer"] = self.output_layer
(decoder_output, decoder_final_state, dec_seq_lengths) = self.decoder(
decoder_initial_states,
encoder_output,
encoder_padding_mask,
**decoder_kwargs
)
if self.decoder.decoding_strategy == "beam_search": # for inference
return decoder_output
logits, samples, sample_length = (
decoder_output.cell_outputs,
decoder_output.sample_ids,
dec_seq_lengths,
)
probs = paddle.nn.functional.softmax(logits)
return probs, samples, sample_length
class PolicyGradient:
"""policy gradient"""
......@@ -477,91 +224,6 @@ class SeqPGAgent:
return results
class TestDynamicDecode(unittest.TestCase):
def setUp(self):
np.random.seed(123)
self.model_hparams = {
"num_layers": 2,
"hidden_size": 32,
"dropout_prob": 0.1,
"src_vocab_size": 100,
"trg_vocab_size": 100,
"start_token": 0,
"end_token": 1,
"decoding_strategy": "infer_greedy",
"max_decoding_length": 10,
}
self.iter_num = iter_num = 2
self.batch_size = batch_size = 4
src_seq_len = 10
trg_seq_len = 12
self.data = {
"src": np.random.randint(
2,
self.model_hparams["src_vocab_size"],
(iter_num * batch_size, src_seq_len),
).astype("int64"),
"src_sequence_length": np.random.randint(
1, src_seq_len, (iter_num * batch_size,)
).astype("int64"),
"trg": np.random.randint(
2,
self.model_hparams["src_vocab_size"],
(iter_num * batch_size, trg_seq_len),
).astype("int64"),
"trg_sequence_length": np.random.randint(
1, trg_seq_len, (iter_num * batch_size,)
).astype("int64"),
"label": np.random.randint(
2,
self.model_hparams["src_vocab_size"],
(iter_num * batch_size, trg_seq_len, 1),
).astype("int64"),
}
place = (
core.CUDAPlace(0)
if core.is_compiled_with_cuda()
else core.CPUPlace()
)
self.exe = Executor(place)
def test_beam_search_infer(self):
paddle.set_default_dtype("float32")
paddle.enable_static()
self.model_hparams["decoding_strategy"] = "beam_search"
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
source = fluid.data(name="src", shape=[None, None], dtype="int64")
source_length = fluid.data(
name="src_sequence_length", shape=[None], dtype="int64"
)
model = Seq2SeqModel(**self.model_hparams)
output = model(source, source_length)
self.exe.run(startup_program)
for iter_idx in range(self.iter_num):
trans_ids = self.exe.run(
program=main_program,
feed={
"src": self.data["src"][
iter_idx
* self.batch_size : (iter_idx + 1)
* self.batch_size,
:,
],
"src_sequence_length": self.data["src_sequence_length"][
iter_idx
* self.batch_size : (iter_idx + 1)
* self.batch_size
],
},
fetch_list=[output],
)[0]
class ModuleApiTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
......
......@@ -14,26 +14,389 @@
import math
from collections.abc import Sequence
from functools import reduce
from functools import partial, reduce
import numpy as np
import paddle
from paddle import _C_ops, _legacy_C_ops, framework, in_dynamic_mode
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layers import utils
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
from paddle.fluid.layers import control_flow, sequence_lod, utils
from paddle.fluid.layers.utils import flatten, map_structure
from paddle.framework import core
from paddle.nn import Layer
from paddle.nn import functional as F
from paddle.nn import initializer as I
from paddle.static import default_startup_program, program_guard
from paddle.static import Variable, default_startup_program, program_guard
from .container import LayerList
__all__ = []
def rnn(
cell,
inputs,
initial_states=None,
sequence_length=None,
time_major=False,
is_reverse=False,
**kwargs
):
r"""
rnn creates a recurrent neural network specified by RNNCell `cell`,
which performs :code:`cell.call()` (for dygraph mode :code:`cell.forward`)
repeatedly until reaches to the maximum length of `inputs`.
Parameters:
cell(RNNCellBase): An instance of `RNNCellBase`.
inputs(Tensor): the input sequences.
If time_major is True, the shape is
`[time_steps, batch_size, input_size]`
else the shape is `[batch_size, time_steps, input_size]`.
initial_states(Tensor|tuple|list, optional): the initial state of the
rnn cell. Tensor or a possibly nested structure of tensors. If not
provided, `cell.get_initial_states` would be called to produce
the initial state. Defaults to None.
sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
or int32. The valid lengths of input sequences. Defaults to None.
If `sequence_length` is not None, the inputs are treated as
padded sequences. In each input sequence, elements whose time step
index are not less than the valid length are treated as paddings.
time_major (bool, optional): Whether the first dimension of the input means the
time steps. Defaults to False.
is_reverse (bool, optional): Indicate whether to calculate in the reverse
order of input sequences. Defaults to False.
**kwargs: Additional keyword arguments to pass to `forward` of the cell.
Returns:
outputs (Tensor|list|tuple): the output sequence. Tensor or nested
structure of Tensors.
If `time_major` is True, the shape of each tensor in outpus is
`[time_steps, batch_size, hidden_size]`, else
`[batch_size, time_steps, hidden_size]`.
final_states (Tensor|list|tuple): final states. A (possibly nested structure of)
tensor[s], representing the final state for RNN. It has the same
structure of intial state. Each tensor in final states has the same
shape and dtype as the corresponding tensor in initial states.
Examples:
.. code-block:: python
import paddle
paddle.disable_static()
cell = paddle.nn.SimpleRNNCell(16, 32)
inputs = paddle.rand((4, 23, 16))
prev_h = paddle.randn((4, 32))
outputs, final_states = paddle.nn.layer.rnn(cell, inputs, prev_h)
"""
if _non_static_mode():
return _rnn_dynamic_graph(
cell,
inputs,
initial_states,
sequence_length,
time_major,
is_reverse,
**kwargs
)
else:
return _rnn_static_graph(
cell,
inputs,
initial_states,
sequence_length,
time_major,
is_reverse,
**kwargs
)
class ArrayWrapper:
def __init__(self, x):
self.array = [x]
def append(self, x):
self.array.append(x)
return self
def __getitem__(self, item):
return self.array.__getitem__(item)
def _maybe_copy(state, new_state, step_mask):
"""update rnn state or just pass the old state through"""
new_state = paddle.tensor.math._multiply_with_axis(
new_state, step_mask, axis=0
) + paddle.tensor.math._multiply_with_axis(state, (1 - step_mask), axis=0)
return new_state
def _transpose_batch_time(x):
perm = [1, 0] + list(range(2, len(x.shape)))
return paddle.transpose(x, perm)
def _rnn_dynamic_graph(
cell,
inputs,
initial_states=None,
sequence_length=None,
time_major=False,
is_reverse=False,
**kwargs
):
time_step_index = 0 if time_major else 1
flat_inputs = flatten(inputs)
time_steps = flat_inputs[0].shape[time_step_index]
if initial_states is None:
initial_states = cell.get_initial_states(
batch_ref=inputs, batch_dim_idx=1 if time_major else 0
)
if not time_major:
inputs = map_structure(_transpose_batch_time, inputs)
if sequence_length is not None:
mask = sequence_lod.sequence_mask(
sequence_length, maxlen=time_steps, dtype=inputs.dtype
)
mask = paddle.transpose(mask, [1, 0])
if is_reverse:
inputs = map_structure(lambda x: paddle.reverse(x, axis=[0]), inputs)
mask = (
paddle.reverse(mask, axis=[0])
if sequence_length is not None
else None
)
states = initial_states
outputs = []
for i in range(time_steps):
step_inputs = map_structure(lambda x: x[i], inputs)
step_outputs, new_states = cell(step_inputs, states, **kwargs)
if sequence_length is not None:
new_states = map_structure(
partial(_maybe_copy, step_mask=mask[i]), states, new_states
)
states = new_states
outputs = (
map_structure(lambda x: ArrayWrapper(x), step_outputs)
if i == 0
else map_structure(
lambda x, x_array: x_array.append(x), step_outputs, outputs
)
)
final_outputs = map_structure(
lambda x: paddle.stack(x.array, axis=time_step_index), outputs
)
if is_reverse:
final_outputs = map_structure(
lambda x: paddle.reverse(x, axis=time_step_index), final_outputs
)
final_states = new_states
return final_outputs, final_states
def _rnn_static_graph(
cell,
inputs,
initial_states=None,
sequence_length=None,
time_major=False,
is_reverse=False,
**kwargs
):
check_type(inputs, 'inputs', (Variable, list, tuple), 'rnn')
if isinstance(inputs, (list, tuple)):
for i, input_x in enumerate(inputs):
check_variable_and_dtype(
input_x, 'inputs[' + str(i) + ']', ['float32', 'float64'], 'rnn'
)
check_type(
initial_states,
'initial_states',
(Variable, list, tuple, type(None)),
'rnn',
)
check_type(
sequence_length, 'sequence_length', (Variable, type(None)), 'rnn'
)
def _switch_grad(x, stop=False):
x.stop_gradient = stop
return x
if initial_states is None:
initial_states = cell.get_initial_states(
batch_ref=inputs, batch_dim_idx=1 if time_major else 0
)
initial_states = map_structure(_switch_grad, initial_states)
if not time_major:
inputs = map_structure(_transpose_batch_time, inputs)
if sequence_length:
max_seq_len = paddle.shape(flatten(inputs)[0])[0]
mask = sequence_lod.sequence_mask(
sequence_length,
maxlen=max_seq_len,
dtype=flatten(initial_states)[0].dtype,
)
mask = paddle.transpose(mask, [1, 0])
if is_reverse:
inputs = map_structure(lambda x: paddle.reverse(x, axis=[0]), inputs)
mask = paddle.reverse(mask, axis=[0]) if sequence_length else None
# StaticRNN
rnn = control_flow.StaticRNN()
with rnn.step():
inputs = map_structure(rnn.step_input, inputs)
states = map_structure(rnn.memory, initial_states)
copy_states = map_structure(lambda x: x, states)
outputs, new_states = cell(inputs, copy_states, **kwargs)
utils.assert_same_structure(states, new_states)
if sequence_length:
step_mask = rnn.step_input(mask)
new_states = map_structure(
partial(_maybe_copy, step_mask=step_mask), states, new_states
)
map_structure(rnn.update_memory, states, new_states)
flat_outputs = flatten(outputs)
map_structure(rnn.step_output, outputs)
map_structure(rnn.step_output, new_states)
rnn_out = rnn()
final_outputs = rnn_out[: len(flat_outputs)]
final_outputs = utils.pack_sequence_as(outputs, final_outputs)
final_states = map_structure(lambda x: x[-1], rnn_out[len(flat_outputs) :])
final_states = utils.pack_sequence_as(new_states, final_states)
if is_reverse:
final_outputs = map_structure(
lambda x: paddle.reverse(x, axis=[0]), final_outputs
)
if not time_major:
final_outputs = map_structure(_transpose_batch_time, final_outputs)
return (final_outputs, final_states)
def birnn(
cell_fw,
cell_bw,
inputs,
initial_states=None,
sequence_length=None,
time_major=False,
**kwargs
):
r"""
birnn creates a bidirectional recurrent neural network specified by
RNNCell `cell_fw` and `cell_bw`, which performs :code:`cell.call()`
(for dygraph mode :code:`cell.forward`) repeatedly until reaches to
the maximum length of `inputs` and then concat the outputs for both RNNs
along the last axis.
Parameters:
cell_fw(RNNCellBase): An instance of `RNNCellBase`.
cell_bw(RNNCellBase): An instance of `RNNCellBase`.
inputs(Tensor): the input sequences.
If time_major is True, the shape is
`[time_steps, batch_size, input_size]`
else the shape is `[batch_size, time_steps, input_size]`.
initial_states(tuple, optional): A tuple of initial states of
`cell_fw` and `cell_bw`.
If not provided, `cell.get_initial_states` would be called to
produce initial state for each cell. Defaults to None.
sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
or int32. The valid lengths of input sequences. Defaults to None.
If `sequence_length` is not None, the inputs are treated as
padded sequences. In each input sequence, elements whose time step
index are not less than the valid length are treated as paddings.
time_major (bool): Whether the first dimension of the input means the
time steps. Defaults to False.
**kwargs: Additional keyword arguments to pass to `forward` of each cell.
Returns:
outputs (Tensor): the outputs of the bidirectional RNN. It is the
concatenation of the outputs from the forward RNN and backward
RNN along the last axis.
If time major is True, the shape is `[time_steps, batch_size, size]`,
else the shape is `[batch_size, time_steps, size]`, where size is
`cell_fw.hidden_size + cell_bw.hidden_size`.
final_states (tuple): A tuple of the final states of the forward
cell and backward cell.
Examples:
.. code-block:: python
import paddle
paddle.disable_static()
cell_fw = paddle.nn.LSTMCell(16, 32)
cell_bw = paddle.nn.LSTMCell(16, 32)
inputs = paddle.rand((4, 23, 16))
hf, cf = paddle.rand((4, 32)), paddle.rand((4, 32))
hb, cb = paddle.rand((4, 32)), paddle.rand((4, 32))
initial_states = ((hf, cf), (hb, cb))
outputs, final_states = paddle.nn.layer.birnn(
cell_fw, cell_bw, inputs, initial_states)
"""
if initial_states is None:
states_fw = cell_fw.get_initial_states(
batch_ref=inputs, batch_dim_idx=1 if time_major else 0
)
states_bw = cell_fw.get_initial_states(
batch_ref=inputs, batch_dim_idx=1 if time_major else 0
)
else:
states_fw, states_bw = initial_states
outputs_fw, states_fw = rnn(
cell_fw,
inputs,
states_fw,
sequence_length,
time_major=time_major,
**kwargs
)
outputs_bw, states_bw = rnn(
cell_bw,
inputs,
states_bw,
sequence_length,
time_major=time_major,
is_reverse=True,
**kwargs
)
outputs = map_structure(
lambda x, y: paddle.concat([x, y], -1), outputs_fw, outputs_bw
)
final_states = (states_fw, states_bw)
return outputs, final_states
def split_states(states, bidirectional=False, state_components=1):
r"""
Split states of RNN network into possibly nested list or tuple of
......@@ -779,7 +1142,7 @@ class RNN(Layer):
def forward(
self, inputs, initial_states=None, sequence_length=None, **kwargs
):
final_outputs, final_states = paddle.fluid.layers.rnn(
final_outputs, final_states = rnn(
self.cell,
inputs,
initial_states=initial_states,
......@@ -866,7 +1229,7 @@ class BiRNN(Layer):
len(initial_states) == 2
), "length of initial_states should be 2 when it is a list/tuple"
outputs, final_states = paddle.fluid.layers.birnn(
outputs, final_states = birnn(
self.cell_fw,
self.cell_bw,
inputs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册