Commit d88cbf75 authored by G guosheng

Add StackedRNN and BiRNN.

Parent eb20b652
...@@ -25,7 +25,6 @@ from paddle.fluid.dygraph import Embedding, Linear, Layer
from paddle.fluid.layers import BeamSearchDecoder
import hapi.text as text
from hapi.model import Model, Input, set_device
from hapi.text.text import *
...@@ -515,15 +514,142 @@ class TestTransformerBeamSearchDecoder(ModuleApiTest):
class TestSequenceTagging(ModuleApiTest):
    def setUp(self):
        self.inputs = [
np.random.randint(0, 100, (2, 8)).astype("int64"),
np.random.randint(1, 8, (2)).astype("int64"),
np.random.randint(0, 5, (2, 8)).astype("int64")
]
self.outputs = None
self.attrs = {"vocab_size": 100, "num_labels": 5}
self.param_states = {}
@staticmethod
def model_init(self,
vocab_size,
num_labels,
word_emb_dim=128,
grnn_hidden_dim=128,
emb_learning_rate=0.1,
crf_learning_rate=0.1,
bigru_num=2,
init_bound=0.1):
self.tagger = SequenceTagging(vocab_size, num_labels, word_emb_dim,
grnn_hidden_dim, emb_learning_rate,
crf_learning_rate, bigru_num, init_bound)
@staticmethod
def model_forward(self, word, lengths, target=None):
return self.tagger(word, lengths, target)
def make_inputs(self):
inputs = [
Input(
[None, None], "int64", name="word"),
Input(
[None], "int64", name="lengths"),
Input(
[None, None], "int64", name="target"),
]
return inputs
def test_check_output(self):
self.check_output()
class TestSequenceTaggingInfer(TestSequenceTagging):
def setUp(self):
super(TestSequenceTaggingInfer, self).setUp()
self.inputs = self.inputs[:2] # remove target
def make_inputs(self):
inputs = super(TestSequenceTaggingInfer,
self).make_inputs()[:2] # remove target
return inputs
class TestLSTM(ModuleApiTest):
def setUp(self):
shape = (2, 4, 16)
        self.inputs = [np.random.random(shape).astype("float32")]
        self.outputs = None
        self.attrs = {"input_size": 16, "hidden_size": 16, "num_layers": 2}
        self.param_states = {}
    @staticmethod
    def model_init(self, input_size, hidden_size, num_layers):
        self.lstm = LSTM(input_size, hidden_size, num_layers=num_layers)
@staticmethod
def model_forward(self, inputs):
return self.lstm(inputs)[0]
def make_inputs(self):
inputs = [
Input(
[None, None, self.inputs[-1].shape[-1]],
"float32",
name="input"),
]
return inputs
def test_check_output(self):
self.check_output()
class TestBiLSTM(ModuleApiTest):
def setUp(self):
shape = (2, 4, 16)
self.inputs = [np.random.random(shape).astype("float32")]
self.outputs = None
self.attrs = {"input_size": 16, "hidden_size": 16, "num_layers": 2}
self.param_states = {}
@staticmethod
def model_init(self,
input_size,
hidden_size,
num_layers,
merge_mode="concat",
merge_each_layer=False):
self.bilstm = BidirectionalLSTM(
input_size,
hidden_size,
num_layers=num_layers,
merge_mode=merge_mode,
merge_each_layer=merge_each_layer)
@staticmethod
def model_forward(self, inputs):
return self.bilstm(inputs)[0]
def make_inputs(self):
inputs = [
Input(
[None, None, self.inputs[-1].shape[-1]],
"float32",
name="input"),
]
return inputs
def test_check_output_merge0(self):
self.check_output()
def test_check_output_merge1(self):
self.attrs["merge_each_layer"] = True
self.check_output()
class TestGRU(ModuleApiTest):
def setUp(self):
shape = (2, 4, 64)
self.inputs = [np.random.random(shape).astype("float32")]
self.outputs = None
self.attrs = {"input_size": 64, "hidden_size": 128, "num_layers": 2}
self.param_states = {}
@staticmethod
def model_init(self, input_size, hidden_size, num_layers):
self.gru = GRU(input_size, hidden_size, num_layers=num_layers)
    @staticmethod
    def model_forward(self, inputs):
...@@ -542,5 +668,48 @@ class TestSequenceTagging(ModuleApiTest):
        self.check_output()
class TestBiGRU(ModuleApiTest):
def setUp(self):
shape = (2, 4, 64)
self.inputs = [np.random.random(shape).astype("float32")]
self.outputs = None
self.attrs = {"input_size": 64, "hidden_size": 128, "num_layers": 2}
self.param_states = {}
@staticmethod
def model_init(self,
input_size,
hidden_size,
num_layers,
merge_mode="concat",
merge_each_layer=False):
self.bigru = BidirectionalGRU(
input_size,
hidden_size,
num_layers=num_layers,
merge_mode=merge_mode,
merge_each_layer=merge_each_layer)
@staticmethod
def model_forward(self, inputs):
return self.bigru(inputs)[0]
def make_inputs(self):
inputs = [
Input(
[None, None, self.inputs[-1].shape[-1]],
"float32",
name="input"),
]
return inputs
def test_check_output_merge0(self):
self.check_output()
def test_check_output_merge1(self):
self.attrs["merge_each_layer"] = True
self.check_output()
if __name__ == '__main__':
    unittest.main()
...@@ -49,7 +49,9 @@ __all__ = [
    'BeamSearchDecoder', 'MultiHeadAttention', 'FFN',
    'TransformerEncoderLayer', 'TransformerEncoder', 'TransformerDecoderLayer',
    'TransformerDecoder', 'TransformerCell', 'TransformerBeamSearchDecoder',
    'LinearChainCRF', 'CRFDecoding', 'SequenceTagging', 'GRUEncoder',
    'StackedLSTMCell', 'LSTM', 'BidirectionalLSTM', 'StackedGRUCell', 'GRU',
    'BidirectionalGRU'
]
...@@ -241,7 +243,7 @@ class BasicLSTMCell(RNNCell):
        # TODO(guosheng): find better way to resolve constants in __init__
        self._forget_bias = layers.create_global_var(
            shape=[1], dtype=dtype, value=forget_bias, persistable=True)
        self._forget_bias.stop_gradient = True
        self._dtype = dtype
        self._input_size = input_size
...@@ -468,9 +470,11 @@ class BasicLSTMCell(RNNCell):
        new_cell = layers.elementwise_add(
            layers.elementwise_mul(
                pre_cell,
                self._gate_activation(
                    layers.elementwise_add(f, self._forget_bias))),
            layers.elementwise_mul(
                self._gate_activation(i), self._activation(j)))
        new_hidden = self._activation(new_cell) * self._gate_activation(o)
        return new_hidden, [new_hidden, new_cell]
...@@ -1037,9 +1041,9 @@ class TransformerCell(Layer):
    @property
    def state_shape(self):
        # n_layer is a layer count (an int), so it is used directly as the
        # range bound rather than wrapped in len()
        return [{
            "k": [self.decoder.n_head, 0, self.decoder.d_key],
            "v": [self.decoder.n_head, 0, self.decoder.d_value],
        } for i in range(self.decoder.n_layer)]

class TransformerBeamSearchDecoder(layers.BeamSearchDecoder):
...@@ -1787,10 +1791,8 @@ class GRUEncoder(Layer):
                 grnn_hidden_dim,
                 init_bound,
                 num_layers=1,
                 is_bidirection=False):
        super(GRUEncoder, self).__init__()
        self.num_layers = num_layers
        self.is_bidirection = is_bidirection
        self.gru_list = []
...@@ -1827,7 +1829,7 @@ class GRUEncoder(Layer):
                        is_reverse=True,
                        time_major=False)))
    def forward(self, input_feature, h0=None):
        for i in range(self.num_layers):
            pre_gru, pre_state = self.gru_list[i](input_feature)
            if self.is_bidirection:
...@@ -1839,18 +1841,16 @@ class GRUEncoder(Layer):
        return out
class SequenceTagging(Layer):
    def __init__(self,
                 vocab_size,
                 num_labels,
                 word_emb_dim=128,
                 grnn_hidden_dim=128,
                 emb_learning_rate=0.1,
                 crf_learning_rate=0.1,
                 bigru_num=2,
                 init_bound=0.1):
        super(SequenceTagging, self).__init__()
""" """
define the sequence tagging network structure define the sequence tagging network structure
...@@ -1868,7 +1868,6 @@ class SequenceTagging(fluid.dygraph.Layer): ...@@ -1868,7 +1868,6 @@ class SequenceTagging(fluid.dygraph.Layer):
self.emb_lr = emb_learning_rate self.emb_lr = emb_learning_rate
self.crf_lr = crf_learning_rate self.crf_lr = crf_learning_rate
self.bigru_num = bigru_num self.bigru_num = bigru_num
self.batch_size = batch_size
self.init_bound = 0.1 self.init_bound = 0.1
        self.word_embedding = Embedding(
...@@ -1880,20 +1879,11 @@ class SequenceTagging(fluid.dygraph.Layer):
                initializer=fluid.initializer.Uniform(
                    low=-self.init_bound, high=self.init_bound)))
        self.gru_encoder = GRUEncoder(
            input_dim=self.grnn_hidden_dim,
            grnn_hidden_dim=self.grnn_hidden_dim,
            init_bound=self.init_bound,
            num_layers=self.bigru_num,
            is_bidirection=True)
        self.fc = Linear(
...@@ -1936,3 +1926,426 @@ class SequenceTagging(fluid.dygraph.Layer):
        self.linear_chain_crf.weight = self.crf_decoding.weight
        crf_decode = self.crf_decoding(input=emission, length=lengths)
        return crf_decode, lengths
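For context, a minimal dygraph sketch of driving the updated SequenceTagging directly; it is illustrative rather than part of the commit, and the shapes mirror TestSequenceTagging above:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    word = fluid.dygraph.to_variable(
        np.random.randint(0, 100, (2, 8)).astype("int64"))
    lengths = fluid.dygraph.to_variable(
        np.random.randint(1, 8, (2, )).astype("int64"))
    tagger = SequenceTagging(vocab_size=100, num_labels=5)
    # no target: the inference path returns the CRF decode and lengths
    crf_decode, lengths = tagger(word, lengths)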
class StackedRNNCell(RNNCell):
    def __init__(self, cells):
        super(StackedRNNCell, self).__init__()
        self.cells = []
        for i, cell in enumerate(cells):
            self.cells.append(self.add_sublayer("cell_%d" % i, cell))

    def forward(self, inputs, states):
        # feed each cell's output into the next stacked cell
        new_states = []
        for cell, state in zip(self.cells, states):
            inputs, new_state = cell(inputs, state)
            new_states.append(new_state)
        return inputs, new_states
@staticmethod
def stack_param_attr(param_attr, n):
if isinstance(param_attr, (list, tuple)):
assert len(param_attr) == n, (
"length of param_attr should be %d when it is a list/tuple" %
n)
param_attrs = [
fluid.ParamAttr._to_attr(attr) for attr in param_attr
]
else:
param_attrs = []
attr = fluid.ParamAttr._to_attr(param_attr)
for i in range(n):
attr_i = copy.deepcopy(attr)
if attr.name:
attr_i.name = attr_i.name + "_" + str(i)
param_attrs.append(attr_i)
return param_attrs
@property
def state_shape(self):
return [cell.state_shape for cell in self.cells]
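A quick illustration of stack_param_attr (the name "w" is made up for the example): a single ParamAttr is deep-copied once per layer, and a named attribute gets an index suffix so every layer owns distinct parameters.

import paddle.fluid as fluid

attrs = StackedRNNCell.stack_param_attr(fluid.ParamAttr(name="w"), 2)
print([attr.name for attr in attrs])  # ['w_0', 'w_1']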
class StackedLSTMCell(RNNCell):
    """
    A container that stacks multiple LSTM cells: the output of each cell is
    the input of the next one, with optional dropout applied between layers.
    When forget_bias is True, the forget gate bias is folded into the bias
    initializer instead of being added as a runtime constant.
    """
def __init__(self,
input_size,
hidden_size,
gate_activation=None,
activation=None,
forget_bias=1.0,
num_layers=1,
dropout=0.0,
param_attr=None,
bias_attr=None,
dtype="float32"):
super(StackedLSTMCell, self).__init__()
self.dropout = utils.convert_to_list(dropout, num_layers, "dropout",
float)
param_attrs = StackedRNNCell.stack_param_attr(param_attr, num_layers)
bias_attrs = StackedRNNCell.stack_param_attr(bias_attr, num_layers)
        self.cells = []
        for i in range(num_layers):
            if forget_bias is True:
                # np.concatenate expects a sequence of arrays; initialize the
                # forget gate slice (the third of the four stacked gate
                # biases) to ones and let the cell use a zero forget_bias
                bias_attrs[
                    i].initializer = fluid.initializer.NumpyArrayInitializer(
                        np.concatenate([
                            np.zeros(2 * hidden_size), np.ones(hidden_size),
                            np.zeros(hidden_size)
                        ]).astype(dtype))
            self.cells.append(
                self.add_sublayer(
                    "lstm_%d" % i,
                    BasicLSTMCell(
                        input_size=input_size if i == 0 else hidden_size,
                        hidden_size=hidden_size,
                        gate_activation=gate_activation,
                        activation=activation,
                        forget_bias=0.0 if forget_bias is True else forget_bias,
                        param_attr=param_attrs[i],
                        bias_attr=bias_attrs[i],
                        dtype=dtype)))
def forward(self, step_input, states):
new_states = []
for i, cell in enumerate(self.cells):
out, new_state = cell(step_input, states[i])
step_input = layers.dropout(
out,
self.dropout[i],
dropout_implementation='upscale_in_train') if self.dropout[
i] > 0 else out
new_states.append(new_state)
return step_input, new_states
@property
def state_shape(self):
return [cell.state_shape for cell in self.cells]
class LSTM(Layer):
def __init__(self,
input_size,
hidden_size,
gate_activation=None,
activation=None,
forget_bias=1.0,
num_layers=1,
dropout=0.0,
is_reverse=False,
time_major=False,
param_attr=None,
bias_attr=None,
dtype='float32'):
super(LSTM, self).__init__()
lstm_cell = StackedLSTMCell(input_size, hidden_size, gate_activation,
activation, forget_bias, num_layers,
dropout, param_attr, bias_attr, dtype)
self.lstm = RNN(lstm_cell, is_reverse, time_major)
def forward(self, inputs, initial_states=None, sequence_length=None):
return self.lstm(inputs, initial_states, sequence_length)
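A minimal usage sketch for the LSTM wrapper, assuming a Paddle 1.x fluid dygraph environment with this module imported; the shapes follow TestLSTM above:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    # [batch_size, time_steps, input_size] since time_major defaults to False
    x = fluid.dygraph.to_variable(
        np.random.random((2, 4, 16)).astype("float32"))
    lstm = LSTM(input_size=16, hidden_size=16, num_layers=2)
    outputs, final_states = lstm(x)  # outputs: [2, 4, 16]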
class BidirectionalRNN(Layer):
def __init__(self,
cell_fw,
cell_bw,
merge_mode='concat',
time_major=False,
cell_cls=None,
**kwargs):
super(BidirectionalRNN, self).__init__()
self.rnn_fw = RNN(cell_fw, is_reverse=False, time_major=time_major)
self.rnn_bw = RNN(cell_bw, is_reverse=True, time_major=time_major)
if merge_mode == 'concat':
self.merge_func = lambda x, y: layers.concat([x, y], -1)
elif merge_mode == 'sum':
self.merge_func = lambda x, y: layers.elementwise_add(x, y)
elif merge_mode == 'ave':
self.merge_func = lambda x, y: layers.scale(
layers.elementwise_add(x, y), 0.5)
elif merge_mode == 'mul':
self.merge_func = lambda x, y: layers.elementwise_mul(x, y)
elif merge_mode == 'zip':
self.merge_func = lambda x, y: (x, y)
elif merge_mode is None:
self.merge_func = None
else:
raise ValueError('Unsupported value for `merge_mode`: %s' %
merge_mode)
def forward(self, inputs, initial_states=None, sequence_length=None):
if isinstance(initial_states, (list, tuple)):
assert len(
initial_states
) == 2, "length of initial_states should be 2 when it is a list/tuple"
else:
initial_states = [initial_states, initial_states]
outputs_fw, states_fw = self.rnn_fw(inputs, initial_states[0],
sequence_length)
outputs_bw, states_bw = self.rnn_bw(inputs, initial_states[1],
sequence_length)
outputs = map_structure(
self.merge_func, outputs_fw,
outputs_bw) if self.merge_func else (outputs_fw, outputs_bw)
return outputs, (states_fw, states_bw)
@staticmethod
def bidirect_param_attr(param_attr):
if isinstance(param_attr, (list, tuple)):
assert len(
param_attr
) == 2, "length of param_attr should be 2 when it is a list/tuple"
param_attrs = param_attr
else:
param_attrs = []
attr = fluid.ParamAttr._to_attr(param_attr)
attr_fw = copy.deepcopy(attr)
if attr.name:
attr_fw.name = attr_fw.name + "_fw"
param_attrs.append(attr_fw)
attr_bw = copy.deepcopy(attr)
if attr.name:
attr_bw.name = attr_bw.name + "_bw"
param_attrs.append(attr_bw)
return param_attrs
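A sketch of BidirectionalRNN built from two of the BasicLSTMCell instances defined earlier in this file; only merge_mode='concat' widens the output, while sum/ave/mul keep the cell's hidden size:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(
        np.random.random((2, 4, 16)).astype("float32"))
    birnn = BidirectionalRNN(
        cell_fw=BasicLSTMCell(input_size=16, hidden_size=16),
        cell_bw=BasicLSTMCell(input_size=16, hidden_size=16),
        merge_mode="concat")
    outputs, (states_fw, states_bw) = birnn(x)  # outputs: [2, 4, 32]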
class BidirectionalLSTM(Layer):
def __init__(self,
input_size,
hidden_size,
gate_activation=None,
activation=None,
forget_bias=1.0,
num_layers=1,
dropout=0.0,
merge_mode='concat',
merge_each_layer=False,
time_major=False,
param_attr=None,
bias_attr=None,
dtype='float32'):
super(BidirectionalLSTM, self).__init__()
self.num_layers = num_layers
self.merge_mode = merge_mode
self.merge_each_layer = merge_each_layer
param_attrs = BidirectionalRNN.bidirect_param_attr(param_attr)
bias_attrs = BidirectionalRNN.bidirect_param_attr(bias_attr)
if not merge_each_layer:
cell_fw = StackedLSTMCell(input_size, hidden_size, gate_activation,
activation, forget_bias, num_layers,
dropout, param_attrs[0], bias_attrs[0],
dtype)
cell_bw = StackedLSTMCell(input_size, hidden_size, gate_activation,
activation, forget_bias, num_layers,
dropout, param_attrs[1], bias_attrs[1],
dtype)
self.lstm = BidirectionalRNN(
cell_fw, cell_bw, merge_mode=merge_mode, time_major=time_major)
else:
fw_param_attrs = StackedRNNCell.stack_param_attr(param_attrs[0],
num_layers)
bw_param_attrs = StackedRNNCell.stack_param_attr(param_attrs[1],
num_layers)
fw_bias_attrs = StackedRNNCell.stack_param_attr(bias_attrs[0],
num_layers)
bw_bias_attrs = StackedRNNCell.stack_param_attr(bias_attrs[1],
num_layers)
# maybe design cell including both forward and backward later
self.lstm = []
for i in range(num_layers):
cell_fw = StackedLSTMCell(
input_size if i == 0 else (hidden_size * 2
if merge_mode == 'concat' else
hidden_size), hidden_size,
gate_activation, activation, forget_bias, 1, dropout,
fw_param_attrs[i], fw_bias_attrs[i], dtype)
cell_bw = StackedLSTMCell(
input_size if i == 0 else (hidden_size * 2
if merge_mode == 'concat' else
hidden_size), hidden_size,
gate_activation, activation, forget_bias, 1, dropout,
bw_param_attrs[i], bw_bias_attrs[i], dtype)
self.lstm.append(
self.add_sublayer(
"lstm_%d" % i,
BidirectionalRNN(
cell_fw,
cell_bw,
merge_mode=merge_mode,
time_major=time_major)))
def forward(self, inputs, initial_states=None, sequence_length=None):
if not self.merge_each_layer:
return self.lstm(inputs, initial_states, sequence_length)
else:
if isinstance(initial_states, (list, tuple)):
assert len(initial_states) == self.num_layers, (
"length of initial_states should be %d when it is a list/tuple"
% self.num_layers)
else:
initial_states = [initial_states] * self.num_layers
stacked_states = []
for i in range(self.num_layers):
outputs, states = self.lstm[i](inputs, initial_states[i],
sequence_length)
inputs = outputs
stacked_states.append(states)
return outputs, stacked_states
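A usage sketch for BidirectionalLSTM, mirroring TestBiLSTM: with the default merge_each_layer=False, the two stacked directions each run over the whole sequence and their outputs are merged once at the end.

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(
        np.random.random((2, 4, 16)).astype("float32"))
    bilstm = BidirectionalLSTM(
        input_size=16, hidden_size=16, num_layers=2, merge_mode="concat")
    outputs, (states_fw, states_bw) = bilstm(x)  # outputs: [2, 4, 32]

With merge_each_layer=True the merge instead happens after every layer, so deeper layers see already-merged features and the second return value becomes a per-layer list of (forward, backward) state pairs.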
class StackedGRUCell(RNNCell):
    """
    A container that stacks multiple GRU cells: the output of each cell is
    the input of the next one, with optional dropout applied between layers.
    """
def __init__(self,
input_size,
hidden_size,
gate_activation=None,
activation=None,
num_layers=1,
dropout=0.0,
param_attr=None,
bias_attr=None,
dtype="float32"):
super(StackedGRUCell, self).__init__()
self.dropout = utils.convert_to_list(dropout, num_layers, "dropout",
float)
param_attrs = StackedRNNCell.stack_param_attr(param_attr, num_layers)
bias_attrs = StackedRNNCell.stack_param_attr(bias_attr, num_layers)
self.cells = []
for i in range(num_layers):
self.cells.append(
self.add_sublayer(
"gru_%d" % i,
BasicGRUCell(
input_size=input_size if i == 0 else hidden_size,
hidden_size=hidden_size,
gate_activation=gate_activation,
activation=activation,
param_attr=param_attrs[i],
bias_attr=bias_attrs[i],
dtype=dtype)))
def forward(self, step_input, states):
new_states = []
for i, cell in enumerate(self.cells):
out, new_state = cell(step_input, states[i])
step_input = layers.dropout(
out,
self.dropout[i],
dropout_implementation='upscale_in_train') if self.dropout[
i] > 0 else out
new_states.append(new_state)
return step_input, new_states
@property
def state_shape(self):
return [cell.state_shape for cell in self.cells]
class GRU(Layer):
def __init__(self,
input_size,
hidden_size,
gate_activation=None,
activation=None,
num_layers=1,
dropout=0.0,
is_reverse=False,
time_major=False,
param_attr=None,
bias_attr=None,
dtype='float32'):
super(GRU, self).__init__()
gru_cell = StackedGRUCell(input_size, hidden_size, gate_activation,
activation, num_layers, dropout, param_attr,
bias_attr, dtype)
self.gru = RNN(gru_cell, is_reverse, time_major)
def forward(self, inputs, initial_states=None, sequence_length=None):
return self.gru(inputs, initial_states, sequence_length)
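The GRU wrapper mirrors LSTM; a sketch using the shapes from TestGRU:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(
        np.random.random((2, 4, 64)).astype("float32"))
    gru = GRU(input_size=64, hidden_size=128, num_layers=2)
    outputs, final_states = gru(x)  # outputs: [2, 4, 128]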
class BidirectionalGRU(Layer):
def __init__(self,
input_size,
hidden_size,
gate_activation=None,
activation=None,
num_layers=1,
dropout=0.0,
merge_mode='concat',
merge_each_layer=False,
time_major=False,
param_attr=None,
bias_attr=None,
dtype='float32'):
super(BidirectionalGRU, self).__init__()
self.num_layers = num_layers
self.merge_mode = merge_mode
self.merge_each_layer = merge_each_layer
param_attrs = BidirectionalRNN.bidirect_param_attr(param_attr)
bias_attrs = BidirectionalRNN.bidirect_param_attr(bias_attr)
if not merge_each_layer:
cell_fw = StackedGRUCell(input_size, hidden_size, gate_activation,
activation, num_layers, dropout,
param_attrs[0], bias_attrs[0], dtype)
cell_bw = StackedGRUCell(input_size, hidden_size, gate_activation,
activation, num_layers, dropout,
param_attrs[1], bias_attrs[1], dtype)
self.gru = BidirectionalRNN(
cell_fw, cell_bw, merge_mode=merge_mode, time_major=time_major)
else:
fw_param_attrs = StackedRNNCell.stack_param_attr(param_attrs[0],
num_layers)
bw_param_attrs = StackedRNNCell.stack_param_attr(param_attrs[1],
num_layers)
fw_bias_attrs = StackedRNNCell.stack_param_attr(bias_attrs[0],
num_layers)
bw_bias_attrs = StackedRNNCell.stack_param_attr(bias_attrs[1],
num_layers)
# maybe design cell including both forward and backward later
self.gru = []
for i in range(num_layers):
cell_fw = StackedGRUCell(input_size if i == 0 else (
hidden_size * 2 if merge_mode == 'concat' else
hidden_size), hidden_size, gate_activation, activation, 1,
dropout, fw_param_attrs[i],
fw_bias_attrs[i], dtype)
cell_bw = StackedGRUCell(input_size if i == 0 else (
hidden_size * 2 if merge_mode == 'concat' else
hidden_size), hidden_size, gate_activation, activation, 1,
dropout, bw_param_attrs[i],
bw_bias_attrs[i], dtype)
self.gru.append(
self.add_sublayer(
"gru_%d" % i,
BidirectionalRNN(
cell_fw,
cell_bw,
merge_mode=merge_mode,
time_major=time_major)))
def forward(self, inputs, initial_states=None, sequence_length=None):
if not self.merge_each_layer:
return self.gru(inputs, initial_states, sequence_length)
else:
if isinstance(initial_states, (list, tuple)):
assert len(initial_states) == self.num_layers, (
"length of initial_states should be %d when it is a list/tuple"
% self.num_layers)
else:
initial_states = [initial_states] * self.num_layers
stacked_states = []
for i in range(self.num_layers):
outputs, states = self.gru[i](inputs, initial_states[i],
sequence_length)
inputs = outputs
stacked_states.append(states)
return outputs, stacked_states
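And a matching sketch for BidirectionalGRU, with the shapes used by TestBiGRU:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(
        np.random.random((2, 4, 64)).astype("float32"))
    bigru = BidirectionalGRU(
        input_size=64, hidden_size=128, num_layers=2, merge_mode="concat")
    # both directions are concatenated, so the last dim is 2 * hidden_size
    outputs, (states_fw, states_bw) = bigru(x)  # outputs: [2, 4, 256]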