Commit 1d707273 authored by huangyuxin

Fix the bug of sharing the same cell between the forward and backward RNNs in BiGRU and BiRNN

Parent 7181e427
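Before this change, BiRNNWithBN and BiGRUWithBN passed self.fw_cell to both the forward and the reverse nn.RNN wrapper, so the two directions silently shared one set of parameters; the diff below switches the reverse wrapper to self.bw_cell. A minimal sketch of the before/after pattern, using the stock paddle.nn.GRUCell instead of the repository's BN-augmented cells and a hypothetical hidden_size:

import paddle
import paddle.nn as nn

hidden_size = 4  # hypothetical size for illustration only

# Buggy pattern: one cell object drives both directions, so the "backward"
# RNN reuses the forward cell's weights.
shared_cell = nn.GRUCell(hidden_size, hidden_size)
fw_rnn = nn.RNN(shared_cell, is_reverse=False, time_major=False)  # [B, T, D]
bw_rnn = nn.RNN(shared_cell, is_reverse=True, time_major=False)   # [B, T, D]

# Fixed pattern (what the diff below does): an independent cell per direction,
# each with its own parameters, as a bidirectional layer requires.
fw_cell = nn.GRUCell(hidden_size, hidden_size)
bw_cell = nn.GRUCell(hidden_size, hidden_size)
fw_rnn = nn.RNN(fw_cell, is_reverse=False, time_major=False)  # [B, T, D]
bw_rnn = nn.RNN(bw_cell, is_reverse=True, time_major=False)   # [B, T, D]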
@@ -29,13 +29,13 @@ __all__ = ['RNNStack']
class RNNCell(nn.RNNCellBase):
    r"""
    Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
    computes the outputs and updates states.
    The formula used is as follows:
    .. math::
        h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
        y_{t} & = h_{t}
    where :math:`act` is for :attr:`activation`.
    """
@@ -92,7 +92,7 @@ class RNNCell(nn.RNNCellBase):
class GRUCell(nn.RNNCellBase):
    r"""
    Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
    it computes the outputs and updates states.
    The formula for GRU used is as follows:
    .. math::
@@ -101,8 +101,8 @@ class GRUCell(nn.RNNCellBase):
        \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
        h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
        y_{t} & = h_{t}
    where :math:`\sigma` is the sigmoid function, and * is the elementwise
    multiplication operator.
""" """
...@@ -202,7 +202,7 @@ class BiRNNWithBN(nn.Layer): ...@@ -202,7 +202,7 @@ class BiRNNWithBN(nn.Layer):
self.fw_rnn = nn.RNN( self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN( self.bw_rnn = nn.RNN(
self.fw_cell, is_reverse=True, time_major=False) #[B, T, D] self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor): def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
# x, shape [B, T, D] # x, shape [B, T, D]
@@ -246,7 +246,7 @@ class BiGRUWithBN(nn.Layer):
        self.fw_rnn = nn.RNN(
            self.fw_cell, is_reverse=False, time_major=False)  #[B, T, D]
        self.bw_rnn = nn.RNN(
-            self.fw_cell, is_reverse=True, time_major=False)  #[B, T, D]
+            self.bw_cell, is_reverse=True, time_major=False)  #[B, T, D]

    def forward(self, x, x_len):
        # x, shape [B, T, D]
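For reference, the GRU docstring shown above corresponds to the following single-step computation. The reset and update gate lines are not visible in the truncated hunk, so the standard sigmoid-gate form is assumed here, and the W/b keys are hypothetical names that simply mirror the docstring subscripts:

import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(x_t, h_prev, W, b):
    # W and b hold per-gate weight matrices and bias vectors keyed by the
    # docstring subscripts (ir/hr: reset gate, iz/hz: update gate, ic/hc: candidate).
    r_t = sigmoid(W["ir"] @ x_t + b["ir"] + W["hr"] @ h_prev + b["hr"])  # assumed, not shown in hunk
    z_t = sigmoid(W["iz"] @ x_t + b["iz"] + W["hz"] @ h_prev + b["hz"])  # assumed, not shown in hunk
    h_tilde = np.tanh(W["ic"] @ x_t + b["ic"] + r_t * (W["hc"] @ h_prev + b["hc"]))
    h_t = z_t * h_prev + (1.0 - z_t) * h_tilde  # note: z_t gates the previous state
    return h_t  # y_t = h_t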