提交 7d3c6fcd 编写于 作者: V Varuna Jayasiri

lstm chunk

上级 814fd565
......@@ -31,6 +31,7 @@ class HyperLSTMCell(Module):
self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)])
self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
self.layer_norm_c = nn.LayerNorm(hidden_size)
def __call__(self, x: torch.Tensor,
h: torch.Tensor, c: torch.Tensor,
......@@ -69,7 +70,7 @@ class HyperLSTMCell(Module):
c_next = f * c + i * g
# $$h_t = o_t \odot \tanh(c_t)$$
h_next = o * torch.tanh(c_next)
h_next = o * torch.tanh(self.layer_norm_c(c_next))
return h_next, c_next, rhn_h, rhn_c
......
......@@ -55,8 +55,6 @@ class LSTMCell(Module):
def __init__(self, input_size: int, hidden_size: int):
super().__init__()
self.hidden_size = hidden_size
# These are the linear layer to transform the `input` and `hidden` vectors.
# One of them doesn't need a bias since we add the transformations.
......@@ -68,17 +66,18 @@ class LSTMCell(Module):
def __call__(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
# We compute the linear transformations for $i_t$, $f_t$, $g_t$ and $o_t$
# using the same linear layers.
# Each layer produces an output of 4 times the `hidden_size` and we split them later
ifgo = self.hidden_lin(h) + self.input_lin(x)
# Each layer produces an output of 4 times the `hidden_size` and we split them
ifgo = ifgo.chunk(4, dim=-1)
# $$i_t = \sigma\big(lin_{xi}(x_t) + lin_{hi}(h_{t-1})\big)$$
i = torch.sigmoid(ifgo[:, :self.hidden_size])
i = torch.sigmoid(ifgo[0])
# $$f_t = \sigma\big(lin_{xf}(x_t) + lin_{hf}(h_{t-1})\big)$$
f = torch.sigmoid(ifgo[:, self.hidden_size:self.hidden_size * 2])
f = torch.sigmoid(ifgo[1])
# $$g_t = \tanh\big(lin_{xg}(x_t) + lin_{hg}(h_{t-1})\big)$$
g = torch.tanh(ifgo[:, self.hidden_size * 2:self.hidden_size * 3])
g = torch.tanh(ifgo[2])
# $$o_t = \sigma\big(lin_{xo}(x_t) + lin_{ho}(h_{t-1})\big)$$
o = torch.sigmoid(ifgo[:, self.hidden_size * 3:self.hidden_size * 4])
o = torch.sigmoid(ifgo[3])
# $$c_t = f_t \odot c_{t-1} + i_t \odot g_t$$
c_next = f * c + i * g
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册