Commit b7f6e079
Authored on Sep 09, 2020 by Allen Wang
Committed by A. Unique TensorFlower on Sep 09, 2020
Parent: da4aca1c

MultiHeadRelativeAttention compatibility changes with XLNet

PiperOrigin-RevId: 330751568
Showing 1 changed file with 86 additions and 44 deletions.

official/nlp/modeling/layers/attention.py  (+86, -44)
@@ -20,14 +20,29 @@ import string
import tensorflow as tf

from official.nlp.modeling.layers import masked_softmax

EinsumDense = tf.keras.layers.experimental.EinsumDense
MultiHeadAttention = tf.keras.layers.MultiHeadAttention
_CHR_IDX = string.ascii_lowercase


+def _large_compatible_negative(tensor_type):
+  """Large negative number as Tensor.
+
+  This function is necessary because the standard value for epsilon
+  in this module (-1e9) cannot be represented using tf.float16
+
+  Args:
+    tensor_type: a dtype to determine the type.
+
+  Returns:
+    a large negative number.
+  """
+  if tensor_type == tf.float16:
+    return tf.float16.min
+  return -1e9
+
+
@tf.keras.utils.register_keras_serializable(package="Text")
class CachedAttention(tf.keras.layers.MultiHeadAttention):
  """Attention layer with cache used for auto-agressive decoding.
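The new `_large_compatible_negative` helper exists because a literal -1e9 does not fit in float16, so masking with it is not float16-safe. A quick standalone sanity check (a sketch, not part of the diff):

import tensorflow as tf

# float16 tops out around +/-65504, so -1e9 overflows to -inf when cast;
# tf.float16.min is the most negative value that is still representable.
print(tf.cast(tf.constant(-1e9), tf.float16))   # -inf
print(tf.float16.min)                           # -65504.0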
@@ -116,14 +131,15 @@ class CachedAttention(tf.keras.layers.MultiHeadAttention):

def _rel_shift(x, klen=-1):
  """Performs relative shift to form the relative attention score."""
-  x = tf.transpose(x, perm=[1, 2, 0, 3])
+  x = tf.transpose(x, perm=[2, 3, 0, 1])
  x_size = tf.shape(x)

  x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
  x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
  x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
  x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])
-  x = tf.transpose(x, perm=[2, 0, 1, 3])
+  x = tf.transpose(x, perm=[2, 3, 0, 1])

  return x
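The reshape/slice trick in `_rel_shift` is easier to see in two dimensions. The sketch below (hypothetical lengths, not part of the diff) applies the same steps to the two length axes that the real function first brings to the front by transposing:

import tensorflow as tf

qlen, rlen, klen = 3, 6, 4                          # hypothetical lengths
x = tf.reshape(tf.range(qlen * rlen), [qlen, rlen])

shifted = tf.reshape(x, [rlen, qlen])               # swap the two axes by reshaping
shifted = shifted[1:, :]                            # drop the first row
shifted = tf.reshape(shifted, [qlen, rlen - 1])     # fold back to [qlen, rlen - 1]
shifted = shifted[:, :klen]                         # keep the first klen key positions

print(x.numpy())
print(shifted.numpy())   # row i now starts at offset qlen - i: successive rows are shifted by one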
@@ -200,15 +216,17 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
        to certain positions.
  """

  def __init__(self, kernel_initializer="variance_scaling", **kwargs):
    super().__init__(kernel_initializer=kernel_initializer, **kwargs)

  def _build_from_signature(self, query, value, key=None):
    super(MultiHeadRelativeAttention, self)._build_from_signature(
        query=query, value=value, key=key)
    if hasattr(query, "shape"):
      query_shape = tf.TensorShape(query.shape)
    else:
      query_shape = query
    if hasattr(value, "shape"):
      value_shape = tf.TensorShape(value.shape)
    else:
@@ -230,36 +248,16 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
        bias_constraint=self._bias_constraint)
    with tf.init_scope():
-      free_dims = query_shape.rank - 1
-      einsum_equation, bias_axes, output_rank = _build_proj_equation(
+      einsum_equation, _, output_rank = _build_proj_equation(
          key_shape.rank - 1, bound_dims=1, output_dims=2)
      self._encoding_dense = EinsumDense(
          einsum_equation,
          output_shape=_get_output_shape(output_rank - 1,
                                         [self._num_heads, self._key_dim]),
-          bias_axes=bias_axes if self._use_bias else None,
+          bias_axes=None,
          name="encoding",
          **common_kwargs)
-
-      output_shape = [query_shape[-1]]
-      einsum_equation, bias_axes, output_rank = _build_proj_equation(
-          free_dims, bound_dims=2, output_dims=len(output_shape))
-      # TODO(allencwang) - replace all einsums with programmatic equations.
-      einsum_equation = "abcd,ecd->abe"
-      self._output_dense = EinsumDense(
-          einsum_equation,
-          output_shape=_get_output_shape(output_rank - 1, output_shape),
-          bias_axes=bias_axes if self._use_bias else None,
-          name="attention_output",
-          **common_kwargs)
-
-  def _build_attention(self, rank):
-    self._masked_softmax = masked_softmax.MaskedSoftmax(
-        mask_expansion_axes=[1], normalization_axes=[2])
-    self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)

  def compute_attention(self,
                        query,
                        key,
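For reference, the rebuilt `_encoding_dense` is just an einsum projection of the relative position encoding onto per-head key space. A rough standalone equivalent (hypothetical sizes; the equation string is what `_build_proj_equation` would produce for a rank-3 input with one bound and two output dimensions):

import tensorflow as tf

batch, rel_len, hidden, num_heads, key_dim = 2, 7, 16, 4, 8
encoding = tf.random.normal([batch, rel_len, hidden])

encoding_dense = tf.keras.layers.experimental.EinsumDense(
    "abc,cde->abde",                        # project the last axis to heads x key_dim
    output_shape=(None, num_heads, key_dim),
    bias_axes=None)                         # the diff drops the bias on this projection
print(encoding_dense(encoding).shape)       # (2, 7, 4, 8)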
@@ -267,6 +265,9 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
                        position,
                        content_attention_bias,
                        positional_attention_bias,
+                       segment_matrix=None,
+                       segment_encoding=None,
+                       segment_attention_bias=None,
                        attention_mask=None):
    """Computes the attention.
@@ -282,33 +283,59 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
        when calculating the content-based attention score.
      positional_attention_bias: Trainable bias parameter added to the query
        head when calculating the position-based attention score.
+      segment_matrix: Optional `Tensor` representing segmentation IDs used in
+        XLNet.
+      segment_encoding: Optional trainable `Tensor` representing the
+        segmentation encoding as used in XLNet.
+      segment_attention_bias: Optional trainable bias parameter added to the
+        query had when calculating the segment-based attention score used in
+        XLNet.
      attention_mask: (default None) Optional mask that is added to attention
        logits. If state is not None, the mask source sequence dimension should
        extend M.

    Returns:
      attention_output: Multi-headed output of attention computation of shape
-        `[B, T, N, key_dim]`.
+        `[B, S, N, key_dim]`.
    """
-    content_attention = tf.einsum("bind,bjnd->bijn",
-                                  query + content_attention_bias, key)
-    positional_attention = tf.einsum("bind,bjnd->bijn",
-                                     query + positional_attention_bias,
-                                     position)
-    positional_attention = _rel_shift(
-        positional_attention, klen=tf.shape(content_attention)[2])
-
-    attention_scores = tf.multiply((content_attention + positional_attention),
-                                   1.0 / math.sqrt(float(self._key_dim)))
-
-    attention_scores = self._masked_softmax(attention_scores, attention_mask)
+    content_attention = tf.einsum(self._dot_product_equation, key,
+                                  query + content_attention_bias)
+    positional_attention = tf.einsum(self._dot_product_equation, position,
+                                     query + positional_attention_bias)
+    positional_attention = _rel_shift(
+        positional_attention, klen=tf.shape(content_attention)[3])
+
+    if segment_matrix is not None:
+      segment_attention = tf.einsum("bind,snd->bnis",
+                                    query + segment_attention_bias,
+                                    segment_encoding)
+      target_shape = tf.shape(positional_attention)
+      segment_attention = tf.where(
+          tf.broadcast_to(tf.expand_dims(segment_matrix, 1), target_shape),
+          tf.broadcast_to(segment_attention[:, :, :, 1:], target_shape),
+          tf.broadcast_to(segment_attention[:, :, :, :1], target_shape))
+      attention_sum = (
+          content_attention + positional_attention + segment_attention)
+    else:
+      attention_sum = content_attention + positional_attention
+
+    attention_scores = tf.multiply(
+        attention_sum, 1.0 / math.sqrt(float(self._key_dim)))
+
+    # `attention_scores`: `[B, N, S, S + M]`
+    if attention_mask is not None:
+      attention_scores += (
+          _large_compatible_negative(attention_scores.dtype) * attention_mask)
+
+    attention_scores = tf.nn.softmax(attention_scores, 3)

    attention_output = self._dropout_layer(attention_scores)

-    attention_output = tf.einsum("bijn,bjnd->bind", attention_output, value)
+    attention_output = tf.einsum(self._combine_equation, attention_output,
+                                 value)
    return attention_output

  def call(self,
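The segment term is the only genuinely new math in compute_attention: `segment_attention` carries two scores per (batch, head, query) position, and `tf.where` broadcasts `segment_matrix` over heads to pick one of them for every key position. A small sketch of just that selection (hypothetical shapes and values; which of the two indices means "same segment" depends on how the caller builds the matrix):

import tensorflow as tf

batch, heads, qlen, klen = 1, 2, 3, 4
# Boolean relationship between query i and key j, e.g. "same segment".
segment_matrix = tf.constant([[[True, False, True, False],
                               [False, True, False, True],
                               [True, True, False, False]]])     # [B, qlen, klen]
segment_attention = tf.random.normal([batch, heads, qlen, 2])    # [B, N, qlen, 2]

target_shape = [batch, heads, qlen, klen]
selected = tf.where(
    tf.broadcast_to(tf.expand_dims(segment_matrix, 1), target_shape),
    tf.broadcast_to(segment_attention[:, :, :, 1:], target_shape),  # score index 1 where True
    tf.broadcast_to(segment_attention[:, :, :, :1], target_shape))  # score index 0 elsewhere
print(selected.shape)   # (1, 2, 3, 4): one segment score per (query, key) pair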
@@ -318,6 +345,9 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
           positional_attention_bias,
           key=None,
           relative_position_encoding=None,
+          segment_matrix=None,
+          segment_encoding=None,
+          segment_attention_bias=None,
           state=None,
           attention_mask=None):
    """Compute multi-head relative attention over inputs.
@@ -342,6 +372,13 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
      key: attention input.
      relative_position_encoding: relative positional encoding for key and
        value.
+      segment_matrix: Optional `Tensor` representing segmentation IDs used in
+        XLNet.
+      segment_encoding: Optional `Tensor` representing the segmentation
+        encoding as used in XLNet.
+      segment_attention_bias: Optional trainable bias parameter added to the
+        query had when calculating the segment-based attention score used in
+        XLNet.
      state: (default None) optional state. If passed, this is also attended
        over as in TransformerXL.
      attention_mask: (default None) Optional mask that is added to attention
@@ -381,7 +418,12 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
        position=position,
        content_attention_bias=content_attention_bias,
        positional_attention_bias=positional_attention_bias,
+        segment_matrix=segment_matrix,
+        segment_encoding=segment_encoding,
+        segment_attention_bias=segment_attention_bias,
        attention_mask=attention_mask)

    # `attention_output` = [B, S, N, H]
    attention_output = self._output_dense(attention_output)
    return attention_output
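Taken together, the change lets one layer serve both Transformer-XL (no segments) and XLNet (with segments). A rough usage sketch against the post-change signature; the input shapes and the zero-initialized bias variables are assumptions modeled on XLNet, not part of the diff:

import tensorflow as tf
from official.nlp.modeling.layers import attention

batch, seq_len, hidden, num_heads, key_dim = 2, 5, 16, 4, 4
layer = attention.MultiHeadRelativeAttention(
    num_heads=num_heads, key_dim=key_dim)

query = value = tf.random.normal([batch, seq_len, hidden])
relative_position_encoding = tf.random.normal([batch, 2 * seq_len, hidden])
content_attention_bias = tf.Variable(tf.zeros([num_heads, key_dim]))     # "u" in Transformer-XL
positional_attention_bias = tf.Variable(tf.zeros([num_heads, key_dim]))  # "v" in Transformer-XL

# New in this commit: relative segment encoding inputs used by XLNet.
segment_matrix = tf.zeros([batch, seq_len, seq_len], dtype=tf.bool)
segment_encoding = tf.Variable(tf.zeros([2, num_heads, key_dim]))
segment_attention_bias = tf.Variable(tf.zeros([num_heads, key_dim]))

output = layer(
    query,
    value,
    content_attention_bias=content_attention_bias,
    positional_attention_bias=positional_attention_bias,
    relative_position_encoding=relative_position_encoding,
    segment_matrix=segment_matrix,
    segment_encoding=segment_encoding,
    segment_attention_bias=segment_attention_bias)
print(output.shape)   # expected (2, 5, 16)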