Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
9396c6d9
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
9396c6d9
编写于
1月 23, 2018
作者:
Y
ying
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bugs.
上级
3be6c736
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
14 addition
and
21 deletion
+14
-21
python/paddle/v2/fluid/layers/nn.py
python/paddle/v2/fluid/layers/nn.py
+1
-3
python/paddle/v2/fluid/nets.py
python/paddle/v2/fluid/nets.py
+11
-8
python/paddle/v2/fluid/tests/test_multihead_attention.py
python/paddle/v2/fluid/tests/test_multihead_attention.py
+2
-10
未找到文件。
python/paddle/v2/fluid/layers/nn.py
浏览文件 @
9396c6d9
...
...
@@ -21,8 +21,6 @@ from ..framework import Variable
from
..param_attr
import
ParamAttr
from
tensor
import
concat
import
pdb
__all__
=
[
'fc'
,
'embedding'
,
...
...
@@ -1966,7 +1964,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
__check_input
(
x
,
y
)
helper
=
LayerHelper
(
'matmul'
,
**
locals
())
out
=
helper
.
create_tmp_variable
(
dtype
=
helper
.
input_dtype
()
)
out
=
helper
.
create_tmp_variable
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
type
=
'matmul'
,
inputs
=
{
'X'
:
x
,
...
...
python/paddle/v2/fluid/nets.py
浏览文件 @
9396c6d9
...
...
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
pdb
import
layers
__all__
=
[
...
...
@@ -163,7 +162,7 @@ def glu(input, dim=-1):
def
scaled_dot_product_attention
(
queries
,
keys
,
values
,
num_heads
,
num_heads
=
1
,
dropout_rate
=
0.
):
"""
The dot-product attention.
...
...
@@ -259,9 +258,12 @@ def scaled_dot_product_attention(queries,
raise
ValueError
(
"Input(x) should be a 4-D Tensor."
)
trans_x
=
layers
.
transpose
(
x
,
perm
=
[
0
,
2
,
1
,
3
])
return
layers
.
reshape
(
x
=
layers
.
reshape
(
return
layers
.
reshape
(
x
=
trans_x
,
shape
=
[
trans_x
.
shape
[
0
],
trans_x
[
1
],
trans_x
[
2
]
*
trans_x
[
3
]]))
shape
=
map
(
int
,
[
trans_x
.
shape
[
0
],
trans_x
.
shape
[
1
],
trans_x
.
shape
[
2
]
*
trans_x
.
shape
[
3
]
]))
q
=
__split_heads
(
queries
,
num_heads
)
k
=
__split_heads
(
keys
,
num_heads
)
...
...
@@ -271,10 +273,11 @@ def scaled_dot_product_attention(queries,
scaled_q
=
layers
.
scale
(
x
=
q
,
scale
=
key_dim_per_head
**-
0.5
)
product
=
layers
.
matmul
(
x
=
k
,
y
=
scaled_q
,
transpose_y
=
True
)
attn_score
s
=
layers
.
reshape
(
weight
s
=
layers
.
reshape
(
x
=
layers
.
reshape
(
x
=
product
,
shape
=
[
-
1
,
product
.
shape
[
-
1
]],
act
=
"softmax"
),
shape
=
product
.
shape
)
ctx_multiheads
=
layers
.
matmul
(
attn_scores
,
v
)
context
=
__combine_heads
(
ctx_multiheads
)
return
context
if
dropout_rate
:
weights
=
layers
.
dropout
(
x
,
dropout_prob
=
dropout_rate
,
is_test
=
False
)
ctx_multiheads
=
layers
.
matmul
(
weights
,
v
)
return
__combine_heads
(
ctx_multiheads
)
python/paddle/v2/fluid/tests/test_multihead_attention.py
浏览文件 @
9396c6d9
...
...
@@ -17,8 +17,6 @@ import paddle.v2.fluid as fluid
import
paddle.v2.fluid.core
as
core
import
numpy
as
np
import
pdb
class
TestMultiheadAttention
(
unittest
.
TestCase
):
def
gen_random_input
(
self
):
...
...
@@ -45,7 +43,7 @@ class TestMultiheadAttention(unittest.TestCase):
append_batch_size
=
False
)
keys
.
stop_gradient
=
False
contexts
,
att_scores
=
fluid
.
nets
.
scaled_dot_product_attention
(
contexts
=
fluid
.
nets
.
scaled_dot_product_attention
(
queries
=
queries
,
keys
=
keys
,
values
=
keys
,
...
...
@@ -84,20 +82,14 @@ class TestMultiheadAttention(unittest.TestCase):
keys
.
set
(
self
.
keys
,
place
)
self
.
inputs
[
"keys"
]
=
keys
self
.
inputs
[
"
values"
]
=
valu
es
self
.
inputs
[
"
queries"
]
=
queri
es
def
test_multihead_attention
(
self
):
self
.
gen_random_input
()
self
.
set_program
()
pdb
.
set_trace
()
self
.
run_program
()
expect_output
=
self
.
l2_normalize
(
self
.
data
,
axis
,
epsilon
)
# check output
self
.
assertTrue
(
np
.
allclose
(
self
.
op_output
,
expect_output
,
atol
=
0.001
))
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录