Greenplum / Annotated Deep Learning Paper Implementations
Commit 3e240cfe
Authored on Sep 04, 2020 by Varuna Jayasiri
Commit message: annotations
Parent: 89ca5604

Showing 3 changed files with 106 additions and 19 deletions (+106, −19)
* labml_nn/transformers/__init__.py (+9, −0)
* labml_nn/transformers/mha.py (+28, −8)
* labml_nn/transformers/relative_mha.py (+69, −11)
labml_nn/transformers/__init__.py @ 3e240cfe
"""
# Transformers
* [Multi-head attention](mha.html)
* [Relative multi-head attention](relative_mha.html)
* [Transformer models](models.html)
* [Fixed positional encoding](positional_encoding.html)
"""
from
.configs
import
TransformerConfigs
from
.models
import
TransformerLayer
,
Encoder
,
Decoder
,
Generator
,
EncoderDecoder
from
.mha
import
MultiHeadAttention
...
...
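Since this `__init__.py` re-exports the classes above, downstream code can import them from the package root. A minimal usage sketch (not part of the diff; it assumes the `labml_nn` package is installed):

```python
# Import the re-exported classes shown in the __init__.py above.
from labml_nn.transformers import MultiHeadAttention, TransformerConfigs

mha = MultiHeadAttention(heads=8, d_model=512)  # defaults: dropout_prob=0.1, bias=True
```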
labml_nn/transformers/mha.py @ 3e240cfe
"""
# Multi-Headed Attention
The implementation is inspired from [Annotated Transformer](https://nlp.seas.harvard.edu/2018/04/03/attention.html)
"""
import
math
from
typing
import
Optional
...
...
@@ -10,6 +16,10 @@ from labml_helpers.module import Module
class PrepareForMultiHeadAttention(Module):
    """
    This module does a linear transformation and splits the vector into given
    number of heads for multi-head attention.
    """

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
...
...
@@ -17,22 +27,27 @@ class PrepareForMultiHeadAttention(Module):
        self.d_k = d_k

    def __call__(self, x: torch.Tensor):
        # Input has shape `[seq_len, batch_size, d_model]`
        seq_len, batch_size, _ = x.shape
        x = self.linear(x)
        x = x.view(seq_len, batch_size, self.heads, self.d_k)

        # Output has shape `[seq_len, batch_size, heads, d_k]`
        return x
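To make the shape bookkeeping concrete, here is a small standalone sketch (not part of the commit; the dimensions are arbitrary example values) of the projection-and-split that `PrepareForMultiHeadAttention` performs:

```python
import torch
from torch import nn

seq_len, batch_size, d_model, heads = 10, 2, 512, 8
d_k = d_model // heads  # 64 features per head

# Linear projection followed by splitting the last dimension into heads,
# mirroring the module above.
linear = nn.Linear(d_model, heads * d_k, bias=True)
x = torch.randn(seq_len, batch_size, d_model)         # [seq_len, batch_size, d_model]
x = linear(x).view(seq_len, batch_size, heads, d_k)   # [seq_len, batch_size, heads, d_k]
print(x.shape)  # torch.Size([10, 2, 8, 64])
```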
class MultiHeadAttention(Module):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        """
-       ## Multi-Head Attention
+       ## Multi-Head Attention Module

        This computes multi-headed attention for given `query`, `key` and `value` vectors.

        `heads` is the number of heads.
        `d_model` is the number of features in the `query`, `key` and `value` vectors.

        $$Attention(Q, K, V) = softmax\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
        """
        super().__init__()
...
...
@@ -54,8 +69,12 @@ class MultiHeadAttention(Module):
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """
-       ### Calculate scores between queries and keys
+       ### Calculate scores between queries and keys.
+
+       This method can be overridden for other variations like relative attention.
        """

        # Calculate $Q K^T$
        return torch.einsum('ibhd,jbhd->ijbh', query, key)
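For readers unfamiliar with this `einsum` pattern, the sketch below (not part of the commit; toy sizes) shows that `'ibhd,jbhd->ijbh'` is a per-batch, per-head dot product between every query position and every key position:

```python
import torch

seq_len, batch_size, heads, d_k = 4, 2, 3, 8
query = torch.randn(seq_len, batch_size, heads, d_k)
key = torch.randn(seq_len, batch_size, heads, d_k)

# Same contraction as get_scores: output is [query_len, key_len, batch_size, heads]
scores = torch.einsum('ibhd,jbhd->ijbh', query, key)
print(scores.shape)  # torch.Size([4, 4, 2, 3])

# For a fixed batch b and head h this is just Q @ K^T
b, h = 0, 0
assert torch.allclose(scores[:, :, b, h], query[:, b, h] @ key[:, b, h].T, atol=1e-5)
```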
    def __call__(self, *,
...
...
@@ -69,10 +88,10 @@ class MultiHeadAttention(Module):
        if mask is not None:
            # `mask` has shape `[seq_len, seq_len, batch_size]`,
            # where the first dimension is the query dimension.
-           # If the query dimension is equal to $1$ it will be broadcasted to match
+           # If the query dimension is equal to $1$ it will be broadcasted
            assert mask.shape[0] == 1 or mask.shape[0] == mask.shape[1]

-           # Same mask applied to all `h` heads.
+           # Same mask applied to all heads.
            mask = mask.unsqueeze(-1)

        # Prepare `query`, `key` and `value` for attention computation
...
...
@@ -81,17 +100,18 @@ class MultiHeadAttention(Module):
        key = self.key(key)
        value = self.value(value)

-       # Compute attention scores
+       # Compute attention scores $Q K^T$
+       # Results in a tensor of shape `[seq_len, seq_len, batch_size, heads]`
        scores = self.get_scores(query, key)

-       # Scale scores
+       # Scale scores $\frac{Q K^T}{\sqrt{d_k}}$
        scores *= self.scale

        # Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

-       # $softmax$ attention
+       # $softmax$ attention $softmax\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)$
        attn = F.softmax(scores, dim=1)

        # Save attentions if debugging
...
...
@@ -100,7 +120,7 @@ class MultiHeadAttention(Module):
        # Apply dropout
        attn = self.dropout(attn)

-       # Calculate the attention results
+       # Multiply by values $softmax\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

        # Save attentions for any other calculations
...
...
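Putting the steps of `__call__` together, here is a hedged, self-contained sketch (not the module itself; the mask and sizes are illustrative) of the attention flow on the `[seq_len, seq_len, batch_size, heads]` score layout: scores, scale, mask, softmax over the key dimension, then the weighted sum of values:

```python
import math
import torch
import torch.nn.functional as F

seq_len, batch_size, heads, d_k = 5, 2, 4, 16
query = torch.randn(seq_len, batch_size, heads, d_k)
key = torch.randn(seq_len, batch_size, heads, d_k)
value = torch.randn(seq_len, batch_size, heads, d_k)

# Scores $Q K^T$ scaled by $1/\sqrt{d_k}$
scores = torch.einsum('ibhd,jbhd->ijbh', query, key) * (1 / math.sqrt(d_k))

# Example causal mask: query position i may only attend to key positions j <= i.
# Shape [seq_len, seq_len, batch_size]; unsqueeze(-1) broadcasts it over heads.
mask = torch.tril(torch.ones(seq_len, seq_len)).bool()[:, :, None].expand(-1, -1, batch_size)
scores = scores.masked_fill(mask.unsqueeze(-1) == 0, -1e9)

# Softmax over the key dimension (dim=1), then multiply by the values
attn = F.softmax(scores, dim=1)
x = torch.einsum('ijbh,jbhd->ibhd', attn, value)
print(x.shape)  # torch.Size([5, 2, 4, 16])
```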
labml_nn/transformers/relative_mha.py @ 3e240cfe
"""
Implementation of "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
https://arxiv.org/abs/1901.02860
# Relative Multi-head Attention
This is an implementation of
[Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860)
"""
import
torch
from
torch
import
nn
from
labml.logger
import
inspect
from
.mha
import
MultiHeadAttention
from
labml_nn.transformers.mha
import
MultiHeadAttention
-def relative_shift(x: torch.Tensor):
+def shift_right(x: torch.Tensor):
+    """
+    This method shifts the $i^{th}$ row of a matrix by $i$ columns.
+
+    If the input is `[[1, 2, 3], [4, 5, 6], [7, 8, 9]]`, the shifted
+    result would be `[[1, 2, 3], [0, 4, 5], [6, 0, 7]]`.
+    *Ideally we should mask out the lower triangle, but it's ok for our purpose.*
+    """

    # Concatenate a column of zeros
    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
    x_padded = torch.cat([x, zero_pad], dim=1)

    # Remove excess elements from the end
    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
    x = x_padded[:-1].view_as(x)

    return x
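The pad-and-reshape trick above is easy to verify by hand. A minimal sketch (not part of the commit) reproducing it on a 3×3 matrix:

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# Append a zero column, reinterpret the flat buffer with one extra row,
# then drop the last row: row i ends up shifted right by i positions.
zero_pad = x.new_zeros(x.shape[0], 1)
x_padded = torch.cat([x, zero_pad], dim=1)  # [[1,2,3,0],[4,5,6,0],[7,8,9,0]]
shifted = x_padded.view(x.shape[1] + 1, x.shape[0])[:-1].view_as(x)

# Row 0 is unshifted, row 1 shifted right by 1, row 2 by 2; the entries below
# the diagonal are wrap-around leftovers that would ideally be masked out.
print(shifted)  # tensor([[1, 2, 3], [0, 4, 5], [6, 0, 7]])
```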
class RelativeMultiHeadAttention(MultiHeadAttention):
    """
    ## Relative Multi-Head Attention Module

    We override the [Multi-Head Attention](mha.html) module so we only need to
    write the `get_scores` method.
    """
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        # The linear transformations don't need a bias since we take care of it when
        # calculating the scores.
        # However, having a bias for `value` might make sense.
        super().__init__(heads, d_model, dropout_prob, False)

        self.P = 2 ** 12
        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
...
...
@@ -31,27 +53,63 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
        self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """
        With absolute attention,

        \begin{align}
        A^{abs}_{i,j} &= lin_q(X^q_i + P_i)^T lin_k(X^k_j + P_j) \\
                      &= Q_i^T K_j + Q_i^T U_j + V_i^T K_j + V_i^T U_j
        \end{align}

        where $Q_i$, $K_j$, $V_i$, and $U_j$ are linear transformations of the
        original embeddings and positional encodings.

        They reason out that the attention to a given key should be the same regardless of
        the position of the query. Hence, they replace $V_i^T K_j$ with a constant $v^T K_j$.

        🤔 It may be worthwhile testing without this assumption.

        For the second and third terms, relative positional encodings are introduced.
        So $Q_i^T U_j$ is replaced with $Q_i^T R_{i - j}$ and $V_i^T U_j$ with $S_{i-j}$.

        \begin{align}
        A^{rel}_{i,j} &= Q_i^T K_j + Q_i^T R_{i - j} + v^T K_j + S_{i-j}
        \end{align}
        """
        # $R_{i-j}$ pre-shift
        key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]]
        # $S_{i-j}$ pre-shift
        key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]]
        # $v^T$
        query_pos_bias = self.query_pos_bias[None, None, :, :]

-       ac = torch.einsum('ibhd,jbhd->ijbh', query + self.query_pos_bias[None, None, :, :], key)
+       # $Q_i^T K_j + v^T K_j$
+       ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
        # $Q_i^T R_{i - j}$ pre-shift
        b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
        # $S_{i-j}$ pre-shift
        d = key_pos_bias[None, :, None, :]

-       bd = relative_shift(b + d)
+       # $Q_i^T R_{i - j} + S_{i-j}$
+       bd = shift_right(b + d)
        bd = bd[:, -key.shape[0]:]

        return ac + bd
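As a side note (not from the repo; the sizes are toy values), the slicing of `key_pos_embeddings` above picks out exactly `query_len + key_len` rows centred on the table midpoint `P`, one embedding per relative offset that can occur before the shift:

```python
import torch

P = 8                      # table midpoint (the module uses 2 ** 12)
query_len, key_len = 3, 5  # example sequence lengths
heads, d_k = 2, 4

key_pos_embeddings = torch.zeros(2 * P, heads, d_k)

# Same slice as in get_scores: rows P - query_len .. P + key_len - 1
key_pos_emb = key_pos_embeddings[P - query_len: P + key_len]
print(key_pos_emb.shape)  # torch.Size([8, 2, 4]) -> query_len + key_len rows
```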
-def _test_relative_shift():
+def _test_shift_right():
    x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    inspect(x)
    inspect(shift_right(x))
    x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
    inspect(x[:, :, 0, 0])
-   inspect(relative_shift(x)[:, :, 0, 0])
+   inspect(shift_right(x)[:, :, 0, 0])
    x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1)
    inspect(x[:, :, 0, 0])
-   inspect(relative_shift(x)[:, :, 0, 0])
+   inspect(shift_right(x)[:, :, 0, 0])
if __name__ == '__main__':
-   _test_relative_shift()
+   _test_shift_right()
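Finally, a hedged usage sketch (it assumes `labml_nn` is installed, that this file is importable as `labml_nn.transformers.relative_mha`, and that the keyword-only call signature follows `MultiHeadAttention.__call__` shown earlier):

```python
import torch
from labml_nn.transformers.relative_mha import RelativeMultiHeadAttention

mha = RelativeMultiHeadAttention(heads=4, d_model=64)
x = torch.randn(10, 2, 64)  # [seq_len, batch_size, d_model]
out = mha(query=x, key=x, value=x)
print(out.shape)  # expected: torch.Size([10, 2, 64])
```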