Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
56ec40ad
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
56ec40ad
编写于
11月 20, 2017
作者:
C
Cao Ying
提交者:
GitHub
11月 20, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #4924 from ranqiu92/attention
Add the configuration helper for multi-head attention.
上级
01d6ccb4
f2240293
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
132 addition
and
4 deletion
+132
-4
python/paddle/trainer_config_helpers/networks.py
python/paddle/trainer_config_helpers/networks.py
+132
-4
未找到文件。
python/paddle/trainer_config_helpers/networks.py
浏览文件 @
56ec40ad
...
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
from
activations
import
LinearActivation
,
ReluActivation
,
SoftmaxActivation
,
\
IdentityActivation
,
TanhActivation
,
SequenceSoftmaxActivation
...
...
@@ -26,9 +26,9 @@ __all__ = [
'sequence_conv_pool'
,
'simple_lstm'
,
"simple_img_conv_pool"
,
"img_conv_bn_pool"
,
'lstmemory_group'
,
'lstmemory_unit'
,
'small_vgg'
,
'img_conv_group'
,
'vgg_16_network'
,
'gru_unit'
,
'gru_group'
,
'simple_gru'
,
'simple_attention'
,
'dot_product_attention'
,
'
simple_gru2
'
,
'
bidirectional_gru'
,
'text_conv_pool'
,
'bidirectional_lstm'
,
'inputs
'
,
'outputs'
'simple_attention'
,
'dot_product_attention'
,
'
multi_head_attention
'
,
'
simple_gru2'
,
'bidirectional_gru'
,
'text_conv_pool'
,
'bidirectional_lstm
'
,
'
inputs'
,
'
outputs'
]
######################################################
...
...
@@ -1496,6 +1496,134 @@ def dot_product_attention(encoded_sequence,
input
=
scaled
,
pooling_type
=
SumPooling
(),
name
=
"%s_pooling"
%
name
)
@
wrap_name_default
()
def
multi_head_attention
(
query
,
key
,
value
,
key_proj_size
,
value_proj_size
,
head_num
,
attention_type
,
softmax_param_attr
=
None
,
name
=
None
):
"""
Calculate and return a context vector with dot-product attention mechanism.
The dimension of the context vector equals to value_proj_size * head_num.
Please refer to **Attention Is All You Need** for more details. The link is
as follows:
https://arxiv.org/abs/1706.03762.
The example usage is:
.. code-block:: python
context = multi_head_attention(query=decoder_state,
key=enc_seq,
value=enc_seq,
key_proj_size=64,
value_pro_size=64,
head_num=8,
attention_type='dot-product attention')
:param name: A prefix attached to the name of each layer that defined inside
the multi_head_attention.
:type name: basestring
:param softmax_param_attr: The parameter attribute of sequence softmax
that is used to produce attention weight.
:type softmax_param_attr: ParameterAttribute
:param query: query is used to calculate attention weights over values at current step.
:type query: LayerOutput
:param key: key is used to calculate the attention weight of the corresponding value.
:type key: LayerOutput
:param value: value is the sequence to be attended.
:type value: LayerOutput
:param key_proj_size: The dimension of the linear projection performed on key and query.
:type key_proj_size: int
:param value_proj_size: The dimension of the linear projection performed on value.
:type value_proj_size: int
:param head_num: The number of attention heads.
:type head_num: int
:param attention_type: The type of the attention mechanism used in each attention
heads. Now, we only support scaled dot-product attention and
additive attention.
:type attention_type: basestring
:return: The context vector.
:rtype: LayerOutput
"""
assert
attention_type
in
[
'dot-product attention'
,
'additive attention'
]
with
mixed_layer
(
size
=
key_proj_size
*
head_num
,
name
=
'%s_query_proj'
%
name
)
as
query_proj
:
query_proj
+=
full_matrix_projection
(
query
)
query_proj
=
expand_layer
(
input
=
query_proj
,
expand_as
=
key
)
with
mixed_layer
(
size
=
key_proj_size
*
head_num
,
name
=
'%s_key_proj'
%
name
)
as
key_proj
:
key_proj
+=
full_matrix_projection
(
key
)
with
mixed_layer
(
size
=
value_proj_size
*
head_num
,
name
=
'%s_value_proj'
%
name
)
as
value_proj
:
value_proj
+=
full_matrix_projection
(
value
)
head_list
=
[]
for
i
in
range
(
head_num
):
with
mixed_layer
(
size
=
key_proj_size
)
as
sub_query_proj
:
sub_query_proj
+=
identity_projection
(
query_proj
,
offset
=
key_proj_size
*
i
,
size
=
key_proj_size
)
with
mixed_layer
(
size
=
key_proj_size
)
as
sub_key_proj
:
sub_key_proj
+=
identity_projection
(
key_proj
,
offset
=
key_proj_size
*
i
,
size
=
key_proj_size
)
with
mixed_layer
(
size
=
value_proj_size
)
as
sub_value_proj
:
sub_value_proj
+=
identity_projection
(
value_proj
,
offset
=
value_proj_size
*
i
,
size
=
value_proj_size
)
if
attention_type
==
'dot-product attention'
:
m
=
dot_prod_layer
(
input1
=
sub_query_proj
,
input2
=
sub_key_proj
,
name
=
'%s_dot-product_%d'
%
(
name
,
i
))
m
=
slope_intercept_layer
(
input
=
m
,
slope
=
math
.
sqrt
(
1.0
/
key_proj_size
),
name
=
'%s_dot-product_scaling_%d'
%
(
name
,
i
))
else
:
with
mixed_layer
(
size
=
key_proj_size
,
act
=
TanhActivation
(),
name
=
'%s_combine_%d'
%
(
name
,
i
))
as
m
:
m
+=
identity_projection
(
sub_query_proj
)
m
+=
identity_projection
(
sub_key_proj
)
attention_weight
=
fc_layer
(
input
=
m
,
size
=
1
,
act
=
SequenceSoftmaxActivation
(),
param_attr
=
softmax_param_attr
,
name
=
"%s_softmax_%d"
%
(
name
,
i
),
bias_attr
=
False
)
scaled
=
scaling_layer
(
weight
=
attention_weight
,
input
=
sub_value_proj
,
name
=
'%s_scaling_%d'
%
(
name
,
i
))
head
=
pooling_layer
(
input
=
scaled
,
pooling_type
=
SumPooling
(),
name
=
"%s_pooling_%d"
%
(
name
,
i
))
head_list
.
append
(
head
)
attended
=
concat_layer
(
head_list
)
return
attended
def
inputs
(
layers
,
*
args
):
"""
Declare the inputs of network. The order of input should be as same as
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录