提交 200776bd 编写于 作者: J JiabinYang

add simple rnn

上级 0b6447a4
...@@ -248,3 +248,35 @@ class FC(layers.Layer): ...@@ -248,3 +248,35 @@ class FC(layers.Layer):
outputs={"Out": out}, outputs={"Out": out},
attrs={"use_mkldnn": False}) attrs={"use_mkldnn": False})
return out return out
class SimpleRNNCell(layers.Layer):
def __init__(self, step_input_size, hidden_size, output_size, param_attr):
self.input_size = step_input_size
self.hidden_size = hidden_size
self.output_size = output_size
from ..layer_helper import LayerHelper
self._helper = LayerHelper('SimpleRNNCell', param_attr=param_attr)
def _build_once(self, inputs):
i2h_param_shape = [self.step_input_size, self.hidden_size]
h2h_param_shape = [self.hidden_size, self.hidden_size]
h2o_param_shape = [self.output_size, self.hidden_size]
self._i2h_w = self._helper.create_parameter(
attr=self._helper.param_attr,
shape=i2h_param_shape,
dtype=self._dtype,
is_bias=False)
self._h2h_w = self._helper.create_parameter(
attr=self._helper.param_attr,
shape=h2h_param_shape,
dtype=self._dtype,
is_bias=False)
self._h2o_w = self._helper.create_parameter(
attr=self._helper.param_attr,
shape=h2o_param_shape,
dtype=self._dtype,
is_bias=False)
def forward(self, inputs):
return 1
...@@ -80,6 +80,19 @@ class MLP(fluid.imperative.Layer): ...@@ -80,6 +80,19 @@ class MLP(fluid.imperative.Layer):
return x return x
class SimpleRNN(fluid.imperative.Layer):
def __init__(self, inputs):
super(SimpleRNN, self).__init__()
self.seq_len = input.shape[0]
self._fc1 = FC(3,
fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.1)))
def forward(self, inputs):
for i in range(self.seq_len):
x = self._fc1(inputs[i])
class TestImperative(unittest.TestCase): class TestImperative(unittest.TestCase):
def test_layer(self): def test_layer(self):
with fluid.imperative.guard(): with fluid.imperative.guard():
...@@ -210,6 +223,9 @@ class TestImperative(unittest.TestCase): ...@@ -210,6 +223,9 @@ class TestImperative(unittest.TestCase):
self.assertTrue(np.allclose(dy_out, static_out)) self.assertTrue(np.allclose(dy_out, static_out))
self.assertTrue(np.allclose(dy_grad, static_grad)) self.assertTrue(np.allclose(dy_grad, static_grad))
def test_rnn_ptb(self):
np_inp = np.arrary([])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册