diff --git a/python/paddle/fluid/imperative/nn.py b/python/paddle/fluid/imperative/nn.py index ef1d28e59e0815628967eef2459a8555b4af621c..24f1865f3d1d9a010194333c2bb5abacd992c795 100644 --- a/python/paddle/fluid/imperative/nn.py +++ b/python/paddle/fluid/imperative/nn.py @@ -251,10 +251,16 @@ class FC(layers.Layer): class SimpleRNNCell(layers.Layer): - def __init__(self, step_input_size, hidden_size, output_size, param_attr): + def __init__(self, + step_input_size, + hidden_size, + output_size, + param_attr, + dtype=core.VarDesc.VarType.FP32): self.input_size = step_input_size self.hidden_size = hidden_size self.output_size = output_size + self._dype = core.VarDesc.VarType.FP32 from ..layer_helper import LayerHelper self._helper = LayerHelper('SimpleRNNCell', param_attr=param_attr) @@ -279,4 +285,19 @@ class SimpleRNNCell(layers.Layer): is_bias=False) def forward(self, inputs): + input = inputs[0] + pre_hidden = inputs[1] + out = self._helper.create_variable_for_type_inference(self._dtype) + hidden = self._helper.create_variable_for_type_inference(self._dype) + + self._helper.append_op( + type="mul", + inputs={"X": input, + "Y": self._w}, + outputs={"Out": out}, + attrs={ + "x_num_col_dims": self._num_flatten_dims, + "y_num_col_dims": 1 + }) + return 1