From 6e96261885e117d3b38fc11a2f43087c48975009 Mon Sep 17 00:00:00 2001 From: guosheng Date: Tue, 12 May 2020 22:57:34 +0800 Subject: [PATCH] Add unit test for StackedRNNCell. --- hapi/tests/test_text.py | 34 ++++++++++++++++++++++++++++++++++ hapi/text/text.py | 3 +++ 2 files changed, 37 insertions(+) diff --git a/hapi/tests/test_text.py b/hapi/tests/test_text.py index 977656c..9e5d8b0 100644 --- a/hapi/tests/test_text.py +++ b/hapi/tests/test_text.py @@ -567,6 +567,40 @@ class TestSequenceTaggingInfer(TestSequenceTagging): return inputs +class TestStackedRNN(ModuleApiTest): + def setUp(self): + shape = (2, 4, 16) + self.inputs = [np.random.random(shape).astype("float32")] + self.outputs = None + self.attrs = {"input_size": 16, "hidden_size": 16, "num_layers": 2} + self.param_states = {} + + @staticmethod + def model_init(self, input_size, hidden_size, num_layers): + cells = [ + BasicLSTMCell(input_size, hidden_size), + BasicLSTMCell(hidden_size, hidden_size) + ] + stacked_cell = StackedRNNCell(cells) + self.lstm = RNN(stacked_cell) + + @staticmethod + def model_forward(self, inputs): + return self.lstm(inputs)[0] + + def make_inputs(self): + inputs = [ + Input( + [None, None, self.inputs[-1].shape[-1]], + "float32", + name="input"), + ] + return inputs + + def test_check_output(self): + self.check_output() + + class TestLSTM(ModuleApiTest): def setUp(self): shape = (2, 4, 16) diff --git a/hapi/text/text.py b/hapi/text/text.py index 860d3e7..b9569bd 100644 --- a/hapi/text/text.py +++ b/hapi/text/text.py @@ -49,6 +49,8 @@ __all__ = [ 'BasicLSTMCell', 'BasicGRUCell', 'RNN', + 'BidirectionalRNN', + 'StackedRNNCell', 'StackedLSTMCell', 'LSTM', 'BidirectionalLSTM', @@ -1025,6 +1027,7 @@ class StackedRNNCell(RNNCell): """ def __init__(self, cells): + super(StackedRNNCell, self).__init__() self.cells = [] for i, cell in enumerate(cells): self.cells.append(self.add_sublayer("cell_%d" % i, cell)) -- GitLab