Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
4df95edc
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4df95edc
编写于
1月 17, 2018
作者:
C
Cao Ying
提交者:
GitHub
1月 17, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #7602 from guoshengCS/add-dot_product_attention
Add Python wrapper for dot-product-attention.
上级
939e1b1a
db959d63
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
140 addition
and
4 deletion
+140
-4
doc/api/v2/fluid/layers.rst
doc/api/v2/fluid/layers.rst
+6
-0
doc/api/v2/fluid/nets.rst
doc/api/v2/fluid/nets.rst
+6
-0
python/paddle/v2/fluid/layers/nn.py
python/paddle/v2/fluid/layers/nn.py
+71
-0
python/paddle/v2/fluid/nets.py
python/paddle/v2/fluid/nets.py
+53
-0
python/paddle/v2/fluid/tests/test_matmul_op.py
python/paddle/v2/fluid/tests/test_matmul_op.py
+4
-4
未找到文件。
doc/api/v2/fluid/layers.rst
浏览文件 @
4df95edc
...
...
@@ -364,6 +364,12 @@ split
.. autofunction:: paddle.v2.fluid.layers.split
:noindex:
matmul
------
.. autofunction:: paddle.v2.fluid.layers.matmul
:noindex:
logsigmoid
----------
.. autofunction:: paddle.v2.fluid.layers.logsigmoid
...
...
doc/api/v2/fluid/nets.rst
浏览文件 @
4df95edc
...
...
@@ -25,3 +25,9 @@ glu
.. autofunction:: paddle.v2.fluid.nets.glu
:noindex:
dot_product_attention
---------------------
.. autofunction:: paddle.v2.fluid.nets.dot_product_attention
:noindex:
python/paddle/v2/fluid/layers/nn.py
浏览文件 @
4df95edc
...
...
@@ -50,6 +50,7 @@ __all__ = [
'sequence_last_step'
,
'dropout'
,
'split'
,
'matmul'
,
]
...
...
@@ -1597,3 +1598,73 @@ def split(input, num_or_sections, dim=-1):
'axis'
:
dim
})
return
outs
def
matmul
(
x
,
y
,
transpose_x
=
False
,
transpose_y
=
False
,
name
=
None
):
"""
Applies matrix multipication to two tensors. Currently only rank 1 to rank
3 input tensors are supported.
The actual behavior depends on the shapes of :math:`x`, :math:`y` and the
flag values of :attr:`transpose_x`, :attr:`transpose_y`. Specifically:
- If a transpose flag is specified, the last two dimensions of the tensor
are transposed. If the tensor is rank-1 of shape :math:`[D]`, then for
:math:`x` it is treated as :math:`[1, D]` in nontransposed form and as
:math:`[D, 1]` in transposed form, whereas for :math:`y` it is the
opposite: It is treated as :math:`[D, 1]` in nontransposed form and as
:math:`[1, D]` in transposed form.
- After transpose, the two tensors are 2-D or 3-D and matrix multipication
performs in the following way.
- If both are 2-D, they are multiplied like conventional matrices.
- If either is 3-D, it is treated as a stack of matrices residing in the
last two dimensions and a batched matrix multiply supporting broadcast
applies on the two tensors.
Also note that if the raw tensor :math:`x` or :math:`y` is rank-1 and
nontransposed, the prepended or appended dimension :math:`1` will be
removed after matrix multipication.
Args:
x (Variable): The input variable which is a Tensor or LoDTensor.
y (Variable): The input variable which is a Tensor or LoDTensor.
transpose_x (bool): Whether to transpose :math:`x` before multiplication.
transpose_y (bool): Whether to transpose :math:`y` before multiplication.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: The product Tensor variable.
Examples:
.. code-block:: python
# Examples to clarify shapes of the inputs and output
# x: [B, M, K], y: [B, K, N]
fluid.layers.matmul(x, y) # out: [B, M, N]
# x: [B, M, K], y: [K, N]
fluid.layers.matmul(x, y) # out: [B, M, N]
# x: [B, M, K], y: [K]
fluid.layers.matmul(x, y) # out: [B, M]
# x: [M, K], y: [K, N]
fluid.layers.matmul(x, y) # out: [M, N]
# x: [K], y: [K]
fluid.layers.matmul(x, y) # out: [1]
# x: [M], y: [N]
fluid.layers.matmul(x, y, True, True) # out: [M, N]
"""
helper
=
LayerHelper
(
'matmul'
,
**
locals
())
assert
max
(
len
(
x
.
shape
),
len
(
y
.
shape
)
)
<=
3
,
'Currently only rank 1 to rank 3 input tensors are supported.'
out
=
helper
.
create_tmp_variable
(
dtype
=
helper
.
input_dtype
())
helper
.
append_op
(
type
=
'matmul'
,
inputs
=
{
'X'
:
x
,
'Y'
:
y
},
outputs
=
{
'Out'
:
out
},
attrs
=
{
'transpose_X'
:
transpose_x
,
'transpose_Y'
:
transpose_y
})
return
out
python/paddle/v2/fluid/nets.py
浏览文件 @
4df95edc
...
...
@@ -17,6 +17,7 @@ __all__ = [
"simple_img_conv_pool"
,
"sequence_conv_pool"
,
"glu"
,
"dot_product_attention"
,
]
...
...
@@ -150,3 +151,55 @@ def glu(input, dim=-1):
act_b
=
layers
.
sigmoid
(
x
=
b
)
out
=
layers
.
elementwise_mul
(
x
=
a
,
y
=
act_b
)
return
out
def
dot_product_attention
(
querys
,
keys
,
values
):
"""
The dot-product attention.
Attention mechanism can be seen as mapping a query and a set of key-value
pairs to an output. The output is computed as a weighted sum of the values,
where the weight assigned to each value is computed by a compatibility
function (dot-product here) of the query with the corresponding key.
The dot-product attention can be implemented through (batch) matrix
multipication as follows:
.. math::
Attention(Q, K, V)= softmax(QK^\mathrm{T})V
Refer to `Attention Is All You Need
<https://arxiv.org/pdf/1706.03762.pdf>`_.
Note that batch data containing sequences with different lengths is not
supported by this because of the (batch) matrix multipication.
Args:
query (Variable): The input variable which is a Tensor or LoDTensor.
key (Variable): The input variable which is a Tensor or LoDTensor.
value (Variable): The input variable which is a Tensor or LoDTensor.
Returns:
tuple: The Tensor variables representing the output and attention scores.
Examples:
.. code-block:: python
# Suppose q, k, v are tensor variables with the following shape:
# q: [3, 5, 9], k: [3, 6, 9], v: [3, 6, 10]
out, attn_scores = fluid.nets.dot_product_attention(q, k, v)
out.shape # [3, 5, 10]
attn_scores.shape # [3, 5, 6]
"""
assert
keys
.
shape
[
-
2
]
==
values
.
shape
[
-
2
],
'The shapes of keys and values mismatch.'
assert
querys
.
shape
[
-
1
]
==
keys
.
shape
[
-
1
],
'The shapes of querys and keys mismatch.'
product
=
layers
.
matmul
(
x
=
querys
,
y
=
keys
,
transpose_y
=
True
)
attn_scores
=
layers
.
reshape
(
x
=
layers
.
reshape
(
x
=
product
,
shape
=
[
-
1
,
product
.
shape
[
-
1
]],
act
=
'softmax'
),
shape
=
product
.
shape
)
out
=
layers
.
matmul
(
attn_scores
,
values
)
return
out
,
attn_scores
python/paddle/v2/fluid/tests/test_matmul_op.py
浏览文件 @
4df95edc
...
...
@@ -96,18 +96,18 @@ class Generator(object):
self
.
outputs
=
{
'Out'
:
Out
}
def
test_check_output
(
self
):
self
.
check_output
(
atol
=
1e-
2
)
self
.
check_output
(
atol
=
1e-
3
)
def
test_check_grad_normal
(
self
):
self
.
check_grad
([
'X'
,
'Y'
],
'Out'
,
max_relative_error
=
0.5
)
self
.
check_grad
([
'X'
,
'Y'
],
'Out'
,
max_relative_error
=
1e-3
)
def
test_check_grad_ignore_x
(
self
):
self
.
check_grad
(
[
'Y'
],
'Out'
,
max_relative_error
=
0.5
,
no_grad_set
=
set
(
"X"
))
[
'Y'
],
'Out'
,
max_relative_error
=
1e-3
,
no_grad_set
=
set
(
"X"
))
def
test_check_grad_ignore_y
(
self
):
self
.
check_grad
(
[
'X'
],
'Out'
,
max_relative_error
=
0.5
,
no_grad_set
=
set
(
'Y'
))
[
'X'
],
'Out'
,
max_relative_error
=
1e-3
,
no_grad_set
=
set
(
'Y'
))
# Generate test cases for all possibilities
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录