提交 a360f143 编写于 作者: J JiabinYang

little change

上级 7e3280ad
...@@ -251,10 +251,16 @@ class FC(layers.Layer): ...@@ -251,10 +251,16 @@ class FC(layers.Layer):
class SimpleRNNCell(layers.Layer): class SimpleRNNCell(layers.Layer):
def __init__(self, step_input_size, hidden_size, output_size, param_attr): def __init__(self,
step_input_size,
hidden_size,
output_size,
param_attr,
dtype=core.VarDesc.VarType.FP32):
self.input_size = step_input_size self.input_size = step_input_size
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.output_size = output_size self.output_size = output_size
self._dype = core.VarDesc.VarType.FP32
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
self._helper = LayerHelper('SimpleRNNCell', param_attr=param_attr) self._helper = LayerHelper('SimpleRNNCell', param_attr=param_attr)
...@@ -279,4 +285,19 @@ class SimpleRNNCell(layers.Layer): ...@@ -279,4 +285,19 @@ class SimpleRNNCell(layers.Layer):
is_bias=False) is_bias=False)
def forward(self, inputs): def forward(self, inputs):
input = inputs[0]
pre_hidden = inputs[1]
out = self._helper.create_variable_for_type_inference(self._dtype)
hidden = self._helper.create_variable_for_type_inference(self._dype)
self._helper.append_op(
type="mul",
inputs={"X": input,
"Y": self._w},
outputs={"Out": out},
attrs={
"x_num_col_dims": self._num_flatten_dims,
"y_num_col_dims": 1
})
return 1 return 1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册