Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
3e195d86
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
3e195d86
编写于
1月 22, 2018
作者:
Y
ying
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add wrapper for multihead_attention.
上级
430fdc52
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
242 addition
and
51 deletion
+242
-51
python/paddle/v2/fluid/layers/nn.py
python/paddle/v2/fluid/layers/nn.py
+43
-18
python/paddle/v2/fluid/layers/ops.py
python/paddle/v2/fluid/layers/ops.py
+15
-4
python/paddle/v2/fluid/nets.py
python/paddle/v2/fluid/nets.py
+81
-29
python/paddle/v2/fluid/tests/test_multihead_attention.py
python/paddle/v2/fluid/tests/test_multihead_attention.py
+103
-0
未找到文件。
python/paddle/v2/fluid/layers/nn.py
浏览文件 @
3e195d86
...
...
@@ -22,13 +22,38 @@ from ..param_attr import ParamAttr
from
tensor
import
concat
__all__
=
[
'fc'
,
'embedding'
,
'dynamic_lstm'
,
'gru_unit'
,
'linear_chain_crf'
,
'crf_decoding'
,
'cos_sim'
,
'cross_entropy'
,
'square_error_cost'
,
'accuracy'
,
'chunk_eval'
,
'sequence_conv'
,
'conv2d'
,
'sequence_pool'
,
'pool2d'
,
'batch_norm'
,
'beam_search_decode'
,
'conv2d_transpose'
,
'sequence_expand'
,
'lstm_unit'
,
'reduce_sum'
,
'reduce_mean'
,
'reduce_max'
,
'reduce_min'
,
'sequence_first_step'
,
'sequence_last_step'
,
'dropout'
,
'split'
,
'l2_normalize'
,
'matmul'
,
'warpctc'
,
'sequence_reshape'
'fc'
,
'embedding'
,
'dynamic_lstm'
,
'gru_unit'
,
'linear_chain_crf'
,
'crf_decoding'
,
'cos_sim'
,
'cross_entropy'
,
'square_error_cost'
,
'accuracy'
,
'chunk_eval'
,
'sequence_conv'
,
'conv2d'
,
'sequence_pool'
,
'pool2d'
,
'batch_norm'
,
'beam_search_decode'
,
'conv2d_transpose'
,
'sequence_expand'
,
'lstm_unit'
,
'reduce_sum'
,
'reduce_mean'
,
'reduce_max'
,
'reduce_min'
,
'sequence_first_step'
,
'sequence_last_step'
,
'dropout'
,
'split'
,
'l2_normalize'
,
'matmul'
,
'warpctc'
,
'sequence_reshape'
,
]
...
...
@@ -43,14 +68,14 @@ def fc(input,
**Fully Connected Layer**
The fully connected layer can take multiple tensors as its inputs. It
creates a variable (one for each input tensor) called weights for each
input
tensor, which represents a fully connected weight matrix from each input
unit to each output unit. The fully connected layer multiplies each input
tensor with its coresponding weight to produce an output Tensor. If
multiple input tensors are given, the results of multiple multiplications
will be sumed up. If bias_attr is not None, a biases variable will be
created and added to the output. Finally, if activation is not None
,
it will be applied to the output as well.
creates a variable (one for each input tensor) called weights for each
input tensor, which represents a fully connected weight matrix from
each input unit to each output unit. The fully connected layer
multiplies each input tensor with its coresponding weight to produce
an output Tensor. If multiple input tensors are given, the results of
multiple multiplications will be sumed up. If bias_attr is not None,
a biases variable will be created and added to the output. Finally
,
i
f activation is not None, i
t will be applied to the output as well.
This process can be formulated as follows:
...
...
@@ -1813,11 +1838,11 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
- If both are 2-D, they are multiplied like conventional matrices.
- If either is n-D, it is treated as a stack of matrices residing in the
last two dimensions and a batched matrix multiply supporting broadcast
last two dimensions and a batched matrix multiply supporting broadcast
applies on the two tensors.
Also note that if the raw tensor :math:`x` or :math:`y` is rank-1 and
nontransposed, the prepended or appended dimension :math:`1` will be
Also note that if the raw tensor :math:`x` or :math:`y` is rank-1 and
nontransposed, the prepended or appended dimension :math:`1` will be
removed after matrix multiplication.
Args:
...
...
python/paddle/v2/fluid/layers/ops.py
浏览文件 @
3e195d86
...
...
@@ -46,10 +46,21 @@ __activations__ = [
]
__all__
=
[
'mean'
,
'mul'
,
'reshape'
,
'scale'
,
'transpose'
,
'sigmoid_cross_entropy_with_logits'
,
'elementwise_add'
,
'elementwise_div'
,
'elementwise_sub'
,
'elementwise_mul'
,
'elementwise_max'
,
'elementwise_min'
,
'clip'
,
'clip_by_norm'
,
'sequence_softmax'
'mean'
,
'mul'
,
'reshape'
,
'scale'
,
'transpose'
,
'sigmoid_cross_entropy_with_logits'
,
'elementwise_add'
,
'elementwise_div'
,
'elementwise_sub'
,
'elementwise_mul'
,
'elementwise_max'
,
'elementwise_min'
,
'clip'
,
'clip_by_norm'
,
'sequence_softmax'
,
]
+
__activations__
for
_OP
in
set
(
__all__
):
...
...
python/paddle/v2/fluid/nets.py
浏览文件 @
3e195d86
...
...
@@ -127,21 +127,21 @@ def sequence_conv_pool(input,
def
glu
(
input
,
dim
=-
1
):
"""
The gated linear unit composed by split, sigmoid activation and elementwise
multiplication. Specifically, Split the input into two equal sized parts
:math:`a` and :math:`b` along the given dimension and then compute as
The gated linear unit composed by split, sigmoid activation and elementwise
multiplication. Specifically, Split the input into two equal sized parts
:math:`a` and :math:`b` along the given dimension and then compute as
following:
.. math::
{GLU}(a, b)= a \otimes \sigma(b)
Refer to `Language Modeling with Gated Convolutional Networks
Refer to `Language Modeling with Gated Convolutional Networks
<https://arxiv.org/pdf/1612.08083.pdf>`_.
Args:
input (Variable): The input variable which is a Tensor or LoDTensor.
dim (int): The dimension along which to split. If :math:`dim < 0`, the
dim (int): The dimension along which to split. If :math:`dim < 0`, the
dimension to split along is :math:`rank(input) + dim`.
Returns:
...
...
@@ -160,53 +160,105 @@ def glu(input, dim=-1):
return
out
def
dot_product_attention
(
querys
,
keys
,
values
):
def
scaled_dot_product_attention
(
queries
,
keys
,
values
,
num_heads
,
dropout_rate
=
0.
):
"""
The dot-product attention.
Attention mechanism can be seen as mapping a query and a set of key-value
pairs to an output. The output is computed as a weighted sum of the values,
where the weight assigned to each value is computed by a compatibility
function (dot-product here) of the query with the corresponding key.
The dot-product attention can be implemented through (batch) matrix
Attention mechanism can be seen as mapping a query and a set of
key-value pairs to an output. The output is computed as a weighted sum
of the values, where the weight assigned to each value is computed by a
compatibility function (dot-product here) of the query with the
corresponding key.
The dot-product attention can be implemented through (batch) matrix
multipication as follows:
.. math::
Attention(Q, K, V)= softmax(QK^\mathrm{T})V
Refer to `Attention Is All You Need
Refer to `Attention Is All You Need
<https://arxiv.org/pdf/1706.03762.pdf>`_.
Note that batch data containing sequences with different lengths is not
Note that batch data containing sequences with different lengths is not
supported by this because of the (batch) matrix multipication.
Args:
query (Variable): The input variable which is a Tensor or LoDTensor.
query (Variable): The input variable which is a Tensor or
LoDTensor.
key (Variable): The input variable which is a Tensor or LoDTensor.
value (Variable): The input variable which is a Tensor or LoDTensor.
value (Variable): The input variable which is a Tensor or
LoDTensor.
Returns:
tuple: The Tensor variables representing the output and attention scores.
tuple: The Tensor variables representing the output and attention
scores.
Examples:
.. code-block:: python
# Suppose q, k, v are tensor variables with the following
shape:
# q: [3, 5, 9], k: [3, 6, 9], v: [3, 6, 10]
# Suppose q, k, v are tensor variables with the following
#
shape:
q: [3, 5, 9], k: [3, 6, 9], v: [3, 6, 10]
out, attn_scores = fluid.nets.dot_product_attention(q, k, v)
out.shape # [3, 5, 10]
attn_scores.shape # [3, 5, 6]
"""
assert
keys
.
shape
[
-
2
]
==
values
.
shape
[
-
2
],
'The shapes of keys and values mismatch.'
assert
querys
.
shape
[
-
1
]
==
keys
.
shape
[
-
1
],
'The shapes of querys and keys mismatch.'
product
=
layers
.
matmul
(
x
=
querys
,
y
=
keys
,
transpose_y
=
True
)
if
not
(
len
(
queries
.
shape
)
==
len
(
keys
.
shape
)
==
len
(
values
.
shape
)
==
3
):
raise
ValueError
(
"Inputs quries, keys and values should all be 3-D tensors."
)
if
queries
.
shape
[
-
1
]
!=
keys
.
shape
[
-
1
]:
raise
ValueError
(
"The hidden size of queries and keys should be the same."
)
if
keys
.
shape
[
-
2
]
!=
values
.
shape
[
-
2
]:
raise
ValueError
(
"The max sequence length in query batch and in key batch "
"should be the same."
)
if
keys
.
shape
[
-
1
]
%
num_heads
!=
0
:
raise
ValueError
(
"The hidden size of keys (%d) must be divisible "
"by the number of attention heads (%d)."
%
(
keys
.
shape
[
-
1
],
num_heads
))
if
values
.
shape
[
-
1
]
%
num_heads
!=
0
:
raise
ValueError
(
"The hidden size of values (%d) must be divisible "
"by the number of attention heads (%d)."
%
(
values
.
shape
[
-
1
],
num_heads
))
def
__split_heads
(
x
,
num_heads
):
"""
Reshape the last dimension of inpunt tensor x so that it becomes two
dimensions.
Args:
x(Tensor): a 3-D input Tensor.
num_heads(int): The number of heads.
Returns:
a Tensor with shape [..., n, m/n]
"""
hidden_size
=
x
.
shape
[
-
1
]
#
reshaped
=
layers
.
reshape
(
x
=
x
,
shape
=
x
.
shape
[:
-
1
]
+
[
num_heads
,
hidden_size
//
num_heads
])
pass
def
__combine_heads
():
pass
q
=
__split_heads
(
quries
,
num_heads
)
k
=
__split_heads
(
keys
,
num_heads
)
v
=
__split_heads
(
values
,
num_heads
)
key_dim_per_head
=
keys
.
shape
[
-
1
]
//
num_heads
scale
=
key_dim_per_head
**-
0.5
product
=
layers
.
matmul
(
x
=
k
,
y
=
q
,
transpose_y
=
True
)
attn_scores
=
layers
.
reshape
(
x
=
layers
.
reshape
(
x
=
product
,
shape
=
[
-
1
,
product
.
shape
[
-
1
]],
act
=
'softmax'
),
x
=
product
,
shape
=
[
-
1
,
product
.
shape
[
-
1
]],
act
=
"softmax"
),
shape
=
product
.
shape
)
ou
t
=
layers
.
matmul
(
attn_scores
,
values
)
return
ou
t
,
attn_scores
contex
t
=
layers
.
matmul
(
attn_scores
,
values
)
return
contex
t
,
attn_scores
python/paddle/v2/fluid/tests/test_multihead_attention.py
0 → 100644
浏览文件 @
3e195d86
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle.v2.fluid
as
fluid
import
paddle.v2.fluid.core
as
core
import
numpy
as
np
import
pdb
class
TestMultiheadAttention
(
unittest
.
TestCase
):
def
gen_random_input
(
self
):
"""Generate random input data.
"""
# batch_size, max_sequence_length, hidden dimension
self
.
input_shape
=
(
3
,
13
,
16
)
self
.
queries
=
np
.
random
.
random
(
size
=
self
.
input_shape
).
astype
(
"float32"
)
self
.
keys
=
np
.
random
.
random
(
size
=
self
.
input_shape
).
astype
(
"float32"
)
def
set_program
(
self
):
"""Build the test program.
"""
queries
=
fluid
.
layers
.
data
(
name
=
"queries"
,
shape
=
self
.
input_shape
,
dtype
=
"float32"
,
append_batch_size
=
False
)
queries
.
stop_gradient
=
False
keys
=
fluid
.
layers
.
data
(
name
=
"keys"
,
shape
=
self
.
input_shape
,
dtype
=
"float32"
,
append_batch_size
=
False
)
keys
.
stop_gradient
=
False
contexts
,
att_scores
=
fluid
.
nets
.
scaled_dot_product_attention
(
queries
=
queries
,
keys
=
keys
,
values
=
keys
,
num_heads
=
8
,
dropout_rate
=
0.
)
out
=
fluid
.
layers
.
reduce_sum
(
contexts
,
dim
=
None
)
fluid
.
backward
.
append_backward
(
loss
=
out
)
self
.
fetch_list
=
[
contexts
]
def
run_program
(
self
):
"""Run the test program.
"""
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compile_gpu
():
places
.
append
(
core
.
CUDAPlace
(
0
))
for
place
in
places
:
self
.
set_inputs
(
place
)
exe
=
fluid
.
Executor
(
place
)
output
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
self
.
inputs
,
fetch_list
=
self
.
fetch_list
,
return_numpy
=
True
)
self
.
op_output
=
output
def
set_inputs
(
self
,
place
):
"""Set the randomly generated data to the test program.
"""
self
.
inputs
=
{}
queries
=
fluid
.
Tensor
()
queries
.
set
(
self
.
queries
,
place
)
keys
=
fluid
.
Tensor
()
keys
.
set
(
self
.
keys
,
place
)
self
.
inputs
[
"keys"
]
=
keys
self
.
inputs
[
"values"
]
=
values
def
test_multihead_attention
(
self
):
self
.
gen_random_input
()
self
.
set_program
()
pdb
.
set_trace
()
self
.
run_program
()
expect_output
=
self
.
l2_normalize
(
self
.
data
,
axis
,
epsilon
)
# check output
self
.
assertTrue
(
np
.
allclose
(
self
.
op_output
,
expect_output
,
atol
=
0.001
))
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录