diff --git a/python/paddle/fluid/imperative/nn.py b/python/paddle/fluid/imperative/nn.py index 8754e5d4d0c8c829303f1fe9cd39ead36619ac3b..ef1d28e59e0815628967eef2459a8555b4af621c 100644 --- a/python/paddle/fluid/imperative/nn.py +++ b/python/paddle/fluid/imperative/nn.py @@ -248,3 +248,35 @@ class FC(layers.Layer): outputs={"Out": out}, attrs={"use_mkldnn": False}) return out + + +class SimpleRNNCell(layers.Layer): + def __init__(self, step_input_size, hidden_size, output_size, param_attr): + self.input_size = step_input_size + self.hidden_size = hidden_size + self.output_size = output_size + from ..layer_helper import LayerHelper + self._helper = LayerHelper('SimpleRNNCell', param_attr=param_attr) + + def _build_once(self, inputs): + i2h_param_shape = [self.step_input_size, self.hidden_size] + h2h_param_shape = [self.hidden_size, self.hidden_size] + h2o_param_shape = [self.output_size, self.hidden_size] + self._i2h_w = self._helper.create_parameter( + attr=self._helper.param_attr, + shape=i2h_param_shape, + dtype=self._dtype, + is_bias=False) + self._h2h_w = self._helper.create_parameter( + attr=self._helper.param_attr, + shape=h2h_param_shape, + dtype=self._dtype, + is_bias=False) + self._h2o_w = self._helper.create_parameter( + attr=self._helper.param_attr, + shape=h2o_param_shape, + dtype=self._dtype, + is_bias=False) + + def forward(self, inputs): + return 1 diff --git a/python/paddle/fluid/tests/unittests/test_imperative.py b/python/paddle/fluid/tests/unittests/test_imperative.py index 86baff3c589d7b8a14938886b3e2104b0beb1cc9..915b2921d7935e9fa8c608a0674654b5dd13bc5e 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative.py +++ b/python/paddle/fluid/tests/unittests/test_imperative.py @@ -80,6 +80,19 @@ class MLP(fluid.imperative.Layer): return x +class SimpleRNN(fluid.imperative.Layer): + def __init__(self, inputs): + super(SimpleRNN, self).__init__() + self.seq_len = input.shape[0] + self._fc1 = FC(3, + fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.1))) + + def forward(self, inputs): + for i in range(self.seq_len): + x = self._fc1(inputs[i]) + + class TestImperative(unittest.TestCase): def test_layer(self): with fluid.imperative.guard(): @@ -210,6 +223,9 @@ class TestImperative(unittest.TestCase): self.assertTrue(np.allclose(dy_out, static_out)) self.assertTrue(np.allclose(dy_grad, static_grad)) + def test_rnn_ptb(self): + np_inp = np.arrary([]) + if __name__ == '__main__': unittest.main()