提交 6e962618 编写于 作者: G guosheng

Add unit test for StackedRNNCell.

上级 f75b39e8
...@@ -567,6 +567,40 @@ class TestSequenceTaggingInfer(TestSequenceTagging): ...@@ -567,6 +567,40 @@ class TestSequenceTaggingInfer(TestSequenceTagging):
return inputs return inputs
class TestStackedRNN(ModuleApiTest):
def setUp(self):
shape = (2, 4, 16)
self.inputs = [np.random.random(shape).astype("float32")]
self.outputs = None
self.attrs = {"input_size": 16, "hidden_size": 16, "num_layers": 2}
self.param_states = {}
@staticmethod
def model_init(self, input_size, hidden_size, num_layers):
cells = [
BasicLSTMCell(input_size, hidden_size),
BasicLSTMCell(hidden_size, hidden_size)
]
stacked_cell = StackedRNNCell(cells)
self.lstm = RNN(stacked_cell)
@staticmethod
def model_forward(self, inputs):
return self.lstm(inputs)[0]
def make_inputs(self):
inputs = [
Input(
[None, None, self.inputs[-1].shape[-1]],
"float32",
name="input"),
]
return inputs
def test_check_output(self):
self.check_output()
class TestLSTM(ModuleApiTest): class TestLSTM(ModuleApiTest):
def setUp(self): def setUp(self):
shape = (2, 4, 16) shape = (2, 4, 16)
......
...@@ -49,6 +49,8 @@ __all__ = [ ...@@ -49,6 +49,8 @@ __all__ = [
'BasicLSTMCell', 'BasicLSTMCell',
'BasicGRUCell', 'BasicGRUCell',
'RNN', 'RNN',
'BidirectionalRNN',
'StackedRNNCell',
'StackedLSTMCell', 'StackedLSTMCell',
'LSTM', 'LSTM',
'BidirectionalLSTM', 'BidirectionalLSTM',
...@@ -1025,6 +1027,7 @@ class StackedRNNCell(RNNCell): ...@@ -1025,6 +1027,7 @@ class StackedRNNCell(RNNCell):
""" """
def __init__(self, cells): def __init__(self, cells):
super(StackedRNNCell, self).__init__()
self.cells = [] self.cells = []
for i, cell in enumerate(cells): for i, cell in enumerate(cells):
self.cells.append(self.add_sublayer("cell_%d" % i, cell)) self.cells.append(self.add_sublayer("cell_%d" % i, cell))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册