Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
fa211938
M
Models
项目概览
曾经的那一瞬间
/
Models
11 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
fa211938
编写于
9月 10, 2021
作者:
A
A. Unique TensorFlower
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Internal change
PiperOrigin-RevId: 396035361
上级
b0707104
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
19 addition
and
93 deletion
+19
-93
official/nlp/keras_nlp/layers/transformer_encoder_block.py
official/nlp/keras_nlp/layers/transformer_encoder_block.py
+2
-19
official/nlp/modeling/layers/transformer.py
official/nlp/modeling/layers/transformer.py
+12
-59
official/nlp/modeling/models/seq2seq_transformer.py
official/nlp/modeling/models/seq2seq_transformer.py
+5
-15
未找到文件。
official/nlp/keras_nlp/layers/transformer_encoder_block.py
浏览文件 @
fa211938
...
...
@@ -116,9 +116,6 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
self
.
_attention_initializer
=
self
.
_kernel_initializer
self
.
_attention_axes
=
attention_axes
def
_maybe_build
(
self
,
inputs
):
super
().
_maybe_build
(
inputs
[:
1
])
def
build
(
self
,
input_shape
):
if
isinstance
(
input_shape
,
tf
.
TensorShape
):
input_tensor_shape
=
input_shape
...
...
@@ -250,9 +247,6 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
[`query tensor`, `key value tensor`, `attention mask`] to have separate
input streams for the query, and key/value to the multi-head
attention.
[`query tensor`, `key value tensor`, `attention mask`, `pos_embed`] to
have an additional pos_embed that is added to the query and key of
every self-attention layer.
Returns:
An output tensor with the same dimensions as input/query tensor.
...
...
@@ -261,18 +255,13 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
if
len
(
inputs
)
==
2
:
input_tensor
,
attention_mask
=
inputs
key_value
=
None
pos_embed
=
None
elif
len
(
inputs
)
==
3
:
input_tensor
,
key_value
,
attention_mask
=
inputs
pos_embed
=
None
elif
len
(
inputs
)
==
4
:
input_tensor
,
key_value
,
attention_mask
,
pos_embed
=
inputs
else
:
raise
ValueError
(
"Unexpected inputs to %s with length at %d"
%
(
self
.
__class__
,
len
(
inputs
)))
else
:
input_tensor
,
key_value
,
attention_mask
,
pos_embed
=
(
inputs
,
None
,
None
,
None
)
input_tensor
,
key_value
,
attention_mask
=
(
inputs
,
None
,
None
)
if
self
.
_output_range
:
if
self
.
_norm_first
:
...
...
@@ -293,14 +282,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
if
key_value
is
None
:
key_value
=
input_tensor
if
pos_embed
is
None
:
query
=
target_tensor
key
=
key_value
else
:
query
=
target_tensor
+
pos_embed
key
=
key_value
+
pos_embed
attention_output
=
self
.
_attention_layer
(
query
=
query
,
key
=
key
,
value
=
key_value
,
attention_mask
=
attention_mask
)
query
=
target_tensor
,
value
=
key_value
,
attention_mask
=
attention_mask
)
attention_output
=
self
.
_attention_dropout
(
attention_output
)
if
self
.
_norm_first
:
attention_output
=
source_tensor
+
attention_output
...
...
official/nlp/modeling/layers/transformer.py
浏览文件 @
fa211938
...
...
@@ -232,9 +232,6 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
else
:
self
.
_cross_attention_cls
=
attention
.
MultiHeadAttention
def
_maybe_build
(
self
,
inputs
):
super
().
_maybe_build
(
inputs
[:
1
])
def
build
(
self
,
input_shape
):
target_tensor_shape
=
tf
.
TensorShape
(
input_shape
[
0
])
if
len
(
target_tensor_shape
.
as_list
())
!=
3
:
...
...
@@ -373,57 +370,22 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
self
.
intermediate_dense
,
self
.
output_dense
,
self
.
output_layer_norm
]
def
_parse_inputs
(
self
,
inputs
,
multi_channel_cross_attention
):
if
multi_channel_cross_attention
:
if
len
(
inputs
)
<
5
:
def
call
(
self
,
inputs
,
cache
=
None
,
decode_loop_step
=
None
):
if
self
.
multi_channel_cross_attention
:
if
len
(
inputs
)
!=
5
:
raise
ValueError
(
"TransformerDecoderBlock must have
at least
5 inputs, when it uses "
"TransformerDecoderBlock must have 5 inputs, when it uses "
"multi_channel_cross_attention. But it got: %d"
%
len
(
inputs
))
elif
len
(
inputs
)
==
5
:
input_tensor
,
memory
,
attention_mask
,
self_attention_mask
,
context_attention_weights
=
inputs
input_pos_embed
=
None
memory_pos_embed
=
None
elif
len
(
inputs
)
==
6
:
input_tensor
,
memory
,
attention_mask
,
self_attention_mask
,
context_attention_weights
,
input_pos_embed
=
inputs
memory_pos_embed
=
None
else
:
input_tensor
,
memory
,
attention_mask
,
self_attention_mask
,
context_attention_weights
,
input_pos_embed
,
memory_pos_embed
=
inputs
[:
7
]
else
:
context_attention_weights
=
None
if
len
(
inputs
)
<
4
:
raise
ValueError
(
"TransformerDecoderBlock must have at leaset 4 inputs, but it "
"got: %d"
%
len
(
inputs
))
elif
len
(
inputs
)
==
4
:
input_tensor
,
memory
,
attention_mask
,
self_attention_mask
=
inputs
input_pos_embed
=
None
memory_pos_embed
=
None
elif
len
(
inputs
)
==
5
:
input_tensor
,
memory
,
attention_mask
,
self_attention_mask
,
input_pos_embed
=
inputs
memory_pos_embed
=
None
else
:
input_tensor
,
memory
,
attention_mask
,
self_attention_mask
,
input_pos_embed
,
memory_pos_embed
=
inputs
[:
6
]
return
input_tensor
,
memory
,
attention_mask
,
self_attention_mask
,
context_attention_weights
,
input_pos_embed
,
memory_pos_embed
def
call
(
self
,
inputs
,
cache
=
None
,
decode_loop_step
=
None
):
input_tensor
,
memory
,
attention_mask
,
self_attention_mask
,
context_attention_weights
,
input_pos_embed
,
memory_pos_embed
=
self
.
_parse_inputs
(
inputs
,
self
.
multi_channel_cross_attention
)
elif
len
(
inputs
)
!=
4
:
raise
ValueError
(
"TransformerDecoderBlock must have 4 inputs, but it got: %d"
%
len
(
inputs
))
input_tensor
,
memory
,
attention_mask
,
self_attention_mask
=
inputs
[:
4
]
source_tensor
=
input_tensor
if
self
.
_norm_first
:
input_tensor
=
self
.
self_attention_layer_norm
(
input_tensor
)
if
input_pos_embed
is
None
:
self_attn_query
=
input_tensor
self_attn_key
=
input_tensor
else
:
self_attn_query
=
input_tensor
+
input_pos_embed
self_attn_key
=
input_tensor
+
input_pos_embed
self_attention_output
,
cache
=
self
.
self_attention
(
query
=
self_attn_query
,
key
=
self_attn_key
,
query
=
input_tensor
,
value
=
input_tensor
,
attention_mask
=
self_attention_mask
,
cache
=
cache
,
...
...
@@ -438,22 +400,13 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
source_self_attention_output
=
self_attention_output
self_attention_output
=
self
.
encdec_attention_layer_norm
(
self_attention_output
)
if
input_pos_embed
is
None
:
cross_attn_query
=
self_attention_output
else
:
cross_attn_query
=
self_attention_output
+
input_pos_embed
if
memory_pos_embed
is
None
:
cross_attn_key
=
memory
else
:
cross_attn_key
=
memory
+
memory_pos_embed
cross_attn_inputs
=
dict
(
query
=
cross_attn_query
,
key
=
cross_attn_key
,
query
=
self_attention_output
,
value
=
memory
,
attention_mask
=
attention_mask
)
if
self
.
multi_channel_cross_attention
:
# Accesses the 5-th input tensor for the doc-attention probabilities.
cross_attn_inputs
[
"context_attention_weights"
]
=
context_attention_weights
cross_attn_inputs
[
"context_attention_weights"
]
=
inputs
[
-
1
]
attention_output
=
self
.
encdec_attention
(
**
cross_attn_inputs
)
attention_output
=
self
.
encdec_attention_dropout
(
attention_output
)
if
self
.
_norm_first
:
...
...
official/nlp/modeling/models/seq2seq_transformer.py
浏览文件 @
fa211938
...
...
@@ -425,7 +425,7 @@ class TransformerEncoder(tf.keras.layers.Layer):
base_config
=
super
(
TransformerEncoder
,
self
).
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
def
call
(
self
,
encoder_inputs
,
attention_mask
=
None
,
pos_embed
=
None
):
def
call
(
self
,
encoder_inputs
,
attention_mask
=
None
):
"""Return the output of the encoder.
Args:
...
...
@@ -433,17 +433,14 @@ class TransformerEncoder(tf.keras.layers.Layer):
hidden_size)`.
attention_mask: A mask for the encoder self-attention layer with shape
`(batch_size, input_length, input_length)`.
pos_embed: A tensor or a float that is added to the query and key of every
self-attention layer. Defaults to None.
Returns:
Output of encoder which is a `float32` tensor with shape
`(batch_size, input_length, hidden_size)`.
"""
for
layer_idx
in
range
(
self
.
num_layers
):
encoder_inputs
=
self
.
encoder_layers
[
layer_idx
](
[
encoder_inputs
,
encoder_inputs
,
attention_mask
,
pos_embed
])
[
encoder_inputs
,
attention_mask
])
output_tensor
=
encoder_inputs
output_tensor
=
self
.
output_normalization
(
output_tensor
)
...
...
@@ -522,7 +519,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
attention_initializer
=
attention_initializer
(
input_shape
[
2
]),
name
=
(
"layer_%d"
%
i
)))
self
.
output_normalization
=
tf
.
keras
.
layers
.
LayerNormalization
(
epsilon
=
self
.
_norm_epsilon
,
dtype
=
"float32"
)
epsilon
=
1e-6
,
dtype
=
"float32"
)
super
(
TransformerDecoder
,
self
).
build
(
input_shape
)
def
get_config
(
self
):
...
...
@@ -548,9 +545,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
cross_attention_mask
=
None
,
cache
=
None
,
decode_loop_step
=
None
,
return_all_decoder_outputs
=
False
,
input_pos_embed
=
None
,
memory_pos_embed
=
None
):
return_all_decoder_outputs
=
False
):
"""Return the output of the decoder layer stacks.
Args:
...
...
@@ -570,10 +565,6 @@ class TransformerDecoder(tf.keras.layers.Layer):
return_all_decoder_outputs: Return all decoder layer outputs.
Note that the outputs are layer normed.
This is useful when introducing per layer auxiliary loss.
input_pos_embed: A tensor or float that is added to the target embedding
in every self-attention and cross-attention layer. Defaults to None.
memory_pos_embed: A tensor or float that is added to the memory embedding
in every cross-attention layer. Defaults to None.
Returns:
Output of decoder.
...
...
@@ -584,8 +575,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
decoder_outputs
=
[]
for
layer_idx
in
range
(
self
.
num_layers
):
transformer_inputs
=
[
output_tensor
,
memory
,
cross_attention_mask
,
self_attention_mask
,
input_pos_embed
,
memory_pos_embed
output_tensor
,
memory
,
cross_attention_mask
,
self_attention_mask
]
# Gets the cache for decoding.
if
cache
is
None
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录