Greenplum / Annotated Deep Learning Paper Implementations
Commit 1bac3834
Authored Mar 08, 2021 by Varuna Jayasiri
Parent: 6e85f34c

fast weights

Showing 1 changed file with 21 additions and 33 deletions (+21 −33)

labml_nn/transformers/fast_weights/__init__.py
@@ -6,7 +6,6 @@ summary: >
     Linear Transformers Are Secretly Fast Weight Memory Systems in PyTorch.
 ---
 """
-from typing import Optional
 
 import torch
 from torch import nn
@@ -61,27 +60,29 @@ class FastWeightAttention(Module):
         # Dropout
         self.dropout = nn.Dropout(dropout_prob)
 
-    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
+    def __call__(self, x: torch.Tensor):
+        seq_len = x.shape[0]
         query = self.sigma(self.query(x))
         key = self.sigma(self.key(x))
         value = self.value(x)
-        beta = self.gate(x)
 
-        if weights is None:
-            weights = key.new_zeros((key.shape[0], key.shape[1], value.shape[2], key.shape[2]))
+        weights = key.new_zeros((key.shape[1], key.shape[2], value.shape[3], key.shape[3]))
 
-        value_existing = torch.einsum('bhvk,bhk->bhv', weights, key)
+        outputs = []
+        beta = self.gate(x)
 
-        weights = weights + torch.einsum('bhv,bhk->bhvk', beta * (value - value_existing), key)
+        for i in range(seq_len):
+            value_existing = torch.einsum('bhvk,bhk->bhv', weights, key[i])
 
-        x = torch.einsum('bhvk,bhk->bhv', weights, query)
+            weights = weights + torch.einsum('bhv,bhk->bhvk', beta[i] * (value[i] - value_existing), key[i])
 
-        # Concatenate multiple heads
-        x = x.reshape(x.shape[0], -1)
+            x = torch.einsum('bhvk,bhk->bhv', weights, query[i])
+
+            # Concatenate multiple heads
+            outputs.append(x.reshape(x.shape[0], -1))
+
+        x = torch.stack(outputs)
 
         # Output layer
-        return self.output(x), weights
+        return self.output(x)
 
 
 class FastWeightAttentionTransformerLayer(Module):
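The loop added above implements the delta-rule fast weight update: at each step the memory is read for the current key, v̄ᵢ = Wᵢ₋₁kᵢ, then written as Wᵢ = Wᵢ₋₁ + βᵢ(vᵢ − v̄ᵢ) ⊗ kᵢ. Below is a minimal standalone sketch of one such step that mirrors the two einsum calls in the diff; the dimensions and the softmax feature map are illustrative assumptions (the actual module uses learned projections and a DPFP-style sigma), not part of this commit.

import torch

# Illustrative dimensions: b=batch, h=heads, k=key features, v=value features.
batch, heads, d_k, d_v = 2, 4, 8, 8
weights = torch.zeros(batch, heads, d_v, d_k)       # fast weight memory W_0
k = torch.randn(batch, heads, d_k).softmax(dim=-1)  # key features for step i
v = torch.randn(batch, heads, d_v)                  # value for step i
beta = torch.rand(batch, heads, 1)                  # write-strength gate in [0, 1)

# Read what the memory currently stores for this key: v_bar = W k
v_bar = torch.einsum('bhvk,bhk->bhv', weights, k)
# Delta-rule write: W <- W + beta * (v - v_bar) outer-product k
weights = weights + torch.einsum('bhv,bhk->bhvk', beta * (v - v_bar), k)
# Querying with the same key now retrieves a value moved toward v
out = torch.einsum('bhvk,bhk->bhv', weights, k)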
@@ -102,8 +103,8 @@ class FastWeightAttentionTransformerLayer(Module):
         self.norm_self_attn = nn.LayerNorm([d_model])
         self.norm_ff = nn.LayerNorm([d_model])
 
-    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
-        attn, weights = self.attn(x, weights)
+    def __call__(self, x: torch.Tensor):
+        attn = self.attn(x)
         # Add the self attention results
         x = x + self.dropout(attn)
@@ -115,7 +116,7 @@ class FastWeightAttentionTransformerLayer(Module):
         x = x + self.dropout(ff)
 
         #
-        return x, weights
+        return x
 
 
 class FastWeightAttentionTransformer(Module):
@@ -126,23 +127,10 @@ class FastWeightAttentionTransformer(Module):
         # Final normalization layer
         self.norm = nn.LayerNorm([layer.size])
 
-    def __call__(self, x_seq: torch.Tensor):
-        # Split the input to a list along the sequence axis
-        x_seq = torch.unbind(x_seq, dim=0)
-        # List to store the outputs
-        res = []
-        # For each input step
-        weights = [None for _ in range(len(self.layers))]
-        for x in x_seq:
-            # Run through each layer
-            for i, layer in enumerate(self.layers):
-                # Get layer output
-                x, weights[i] = layer(x, weights[i])
-            res.append(x)
-
-        # Stack the output tensors
-        res = torch.stack(res)
-        # Normalize the output
-        return self.norm(res)
+    def __call__(self, x: torch.Tensor):
+        for i, layer in enumerate(self.layers):
+            # Get layer output
+            x = layer(x)
+
+        # Normalize the output
+        return self.norm(x)
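Taken together, the commit moves the per-step recurrence and the fast-weight state inside FastWeightAttention, so callers no longer step the model token by token while threading an Optional[torch.Tensor] state through each layer. A sketch of the new calling convention follows; the constructor arguments are hypothetical placeholders for illustration and do not appear in this diff.

import torch

# Input layout used throughout the file: [seq_len, batch, d_model].
# Concrete values here are illustrative.
seq_len, batch, d_model = 16, 2, 64
x = torch.randn(seq_len, batch, d_model)

# Hypothetical construction (arguments not shown in this diff):
#   layer = FastWeightAttentionTransformerLayer(...)
#   model = FastWeightAttentionTransformer(layer, n_layers)
#
# Before this commit, callers looped over time steps and passed the per-layer
# fast-weight state back in on every call. After it, a single call processes
# the whole sequence:
#   y = model(x)   # -> [seq_len, batch, d_model]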