Commit 1bac3834 authored by Varuna Jayasiri

fast weights

Parent 6e85f34c
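This commit changes the fast weight attention from a per-step call that threads a `weights` state in and out, to a single call over the whole sequence with the time loop kept inside the module. For orientation, both the old and the new code apply the same delta-rule update at every step; below is a minimal standalone sketch of that update (tensor names and shapes are illustrative, not taken from the diff):

import torch

def fast_weight_step(W: torch.Tensor, q: torch.Tensor, k: torch.Tensor,
                     v: torch.Tensor, beta: torch.Tensor):
    # W: fast weight memory [batch, heads, d_v, d_k]; q, k: [batch, heads, d_k];
    # v: [batch, heads, d_v]; beta: write strength, broadcastable to v.
    v_existing = torch.einsum('bhvk,bhk->bhv', W, k)                       # current recall for key k
    W = W + torch.einsum('bhv,bhk->bhvk', beta * (v - v_existing), k)      # delta-rule write
    out = torch.einsum('bhvk,bhk->bhv', W, q)                              # read with the query
    return out, W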
@@ -6,7 +6,6 @@ summary: >
   Linear Transformers Are Secretly Fast Weight Memory Systems in PyTorch.
 ---
 """
-from typing import Optional
 import torch
 from torch import nn
@@ -61,27 +60,29 @@ class FastWeightAttention(Module):
         # Dropout
         self.dropout = nn.Dropout(dropout_prob)
-    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
+    def __call__(self, x: torch.Tensor):
+        seq_len = x.shape[0]
         query = self.sigma(self.query(x))
         key = self.sigma(self.key(x))
         value = self.value(x)
-        if weights is None:
-            weights = key.new_zeros((key.shape[0], key.shape[1], value.shape[2], key.shape[2]))
-        value_existing = torch.einsum('bhvk,bhk->bhv', weights, key)
-        beta = self.gate(x)
-        weights = weights + torch.einsum('bhv,bhk->bhvk', beta * (value - value_existing), key)
-        x = torch.einsum('bhvk,bhk->bhv', weights, query)
-        # Concatenate multiple heads
-        x = x.reshape(x.shape[0], -1)
+        beta = self.gate(x)
+        weights = key.new_zeros((key.shape[1], key.shape[2], value.shape[3], key.shape[3]))
+        outputs = []
+        for i in range(seq_len):
+            value_existing = torch.einsum('bhvk,bhk->bhv', weights, key[i])
+            weights = weights + torch.einsum('bhv,bhk->bhvk', beta[i] * (value[i] - value_existing), key[i])
+            x = torch.einsum('bhvk,bhk->bhv', weights, query[i])
+            # Concatenate multiple heads
+            outputs.append(x.reshape(x.shape[0], -1))
+        x = torch.stack(outputs)
         # Output layer
-        return self.output(x), weights
+        return self.output(x)
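Since the projections are now applied to the whole sequence up front, `key[i]`, `value[i]`, and `beta[i]` select a single time step inside the loop, and the zero-initialised `weights` tensor drops its leading sequence dimension. Here is a hedged, self-contained shape check of that loop (the sizes, the sigmoid stand-ins for `sigma` and the gate, and the beta shape are assumptions, not taken from the diff):

import torch

# Stand-in sizes; plain sigmoids replace the module's sigma feature map and
# gate just so the snippet runs on its own.
seq_len, batch, heads, d_k, d_v = 8, 2, 4, 16, 16
query = torch.sigmoid(torch.randn(seq_len, batch, heads, d_k))
key = torch.sigmoid(torch.randn(seq_len, batch, heads, d_k))
value = torch.randn(seq_len, batch, heads, d_v)
beta = torch.sigmoid(torch.randn(seq_len, batch, heads, 1))   # broadcasts over d_v

# Matches the indices used in the new code: [batch, heads, d_v, d_k]
weights = key.new_zeros((key.shape[1], key.shape[2], value.shape[3], key.shape[3]))
outputs = []
for i in range(seq_len):
    value_existing = torch.einsum('bhvk,bhk->bhv', weights, key[i])
    weights = weights + torch.einsum('bhv,bhk->bhvk', beta[i] * (value[i] - value_existing), key[i])
    out = torch.einsum('bhvk,bhk->bhv', weights, query[i])
    outputs.append(out.reshape(out.shape[0], -1))   # concatenate heads -> [batch, heads * d_v]

x = torch.stack(outputs)                            # [seq_len, batch, heads * d_v]
assert x.shape == (seq_len, batch, heads * d_v)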
class FastWeightAttentionTransformerLayer(Module):
@@ -102,8 +103,8 @@ class FastWeightAttentionTransformerLayer(Module):
         self.norm_self_attn = nn.LayerNorm([d_model])
         self.norm_ff = nn.LayerNorm([d_model])
-    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
-        attn, weights = self.attn(x, weights)
+    def __call__(self, x: torch.Tensor):
+        attn = self.attn(x)
         # Add the self attention results
         x = x + self.dropout(attn)
@@ -115,7 +116,7 @@ class FastWeightAttentionTransformerLayer(Module):
         x = x + self.dropout(ff)
         #
-        return x, weights
+        return x
class FastWeightAttentionTransformer(Module):
@@ -126,23 +127,10 @@ class FastWeightAttentionTransformer(Module):
         # Final normalization layer
         self.norm = nn.LayerNorm([layer.size])
-    def __call__(self, x_seq: torch.Tensor):
-        # Split the input to a list along the sequence axis
-        x_seq = torch.unbind(x_seq, dim=0)
-        # List to store the outputs
-        res = []
-        # For each input step
-        weights = [None for _ in range(len(self.layers))]
-        for x in x_seq:
-            # Run through each layer
-            for i, layer in enumerate(self.layers):
-                # Get layer output
-                x, weights[i] = layer(x, weights[i])
-            res.append(x)
+    def __call__(self, x: torch.Tensor):
+        for i, layer in enumerate(self.layers):
+            # Get layer output
+            x = layer(x)
-        # Stack the output tensors
-        res = torch.stack(res)
         # Normalize the output
-        return self.norm(res)
+        return self.norm(x)
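At the model level the change deletes the explicit time-step loop and the per-layer `weights` list: the sequence axis is handled inside each attention layer, so the transformer body reduces to a plain layer stack followed by the final LayerNorm. A self-contained sketch of that pattern with a stand-in layer (the class name, sizes, and the `nn.Linear` layer below are illustrative, not part of the commit; the real layer wraps FastWeightAttention and a feed-forward block):

import copy
import torch
from torch import nn

class TinyStack(nn.Module):
    # Stand-in for FastWeightAttentionTransformer after this commit:
    # a plain stack of layers, each consuming the whole sequence tensor.
    def __init__(self, layer: nn.Module, n_layers: int, d_model: int):
        super().__init__()
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
        self.norm = nn.LayerNorm([d_model])

    def forward(self, x: torch.Tensor):
        for layer in self.layers:
            x = layer(x)          # each layer sees the full [seq_len, batch, d_model] tensor
        return self.norm(x)

x = torch.randn(32, 4, 128)                                       # [seq_len, batch_size, d_model]
out = TinyStack(nn.Linear(128, 128), n_layers=2, d_model=128)(x)  # same shape as x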