Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9dd64d83
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9dd64d83
编写于
3月 26, 2018
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
WMT Model
上级
cb40c331
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
660 addition
and
8 deletion
+660
-8
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+10
-7
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+2
-0
paddle/fluid/framework/reader.cc
paddle/fluid/framework/reader.cc
+1
-1
python/paddle/fluid/tests/unittests/.gitignore
python/paddle/fluid/tests/unittests/.gitignore
+1
-0
python/paddle/fluid/tests/unittests/test_parallel_executor.py
...on/paddle/fluid/tests/unittests/test_parallel_executor.py
+159
-0
python/paddle/fluid/tests/unittests/transformer_model.py
python/paddle/fluid/tests/unittests/transformer_model.py
+487
-0
未找到文件。
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
9dd64d83
...
...
@@ -170,13 +170,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for
(
auto
p
:
this
->
places_
)
{
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
)
->
Wait
();
}
// NOTE: the temp scope can be dropped lazily if needed.
// Drop tmp scopes;
for
(
auto
&
scope
:
local_scopes_
)
{
auto
&
kid
=
*
scope
->
Var
(
"@TMP_SCOPE@"
)
->
GetMutable
<
Scope
*>
();
kid
=
nullptr
;
scope
->
DropKids
();
for
(
auto
&
drop_fn
:
this
->
drop_functions_
)
{
drop_fn
();
}
};
...
...
@@ -190,6 +185,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
sync_computation
();
}
// NOTE: the temp scope can be dropped lazily if needed.
// Drop tmp scopes;
for
(
auto
&
scope
:
local_scopes_
)
{
auto
&
kid
=
*
scope
->
Var
(
"@TMP_SCOPE@"
)
->
GetMutable
<
Scope
*>
();
this
->
drop_functions_
.
emplace_back
([
=
]
{
scope
->
DeleteScope
(
kid
);
});
kid
=
nullptr
;
}
return
fetch_data
;
}
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
浏览文件 @
9dd64d83
...
...
@@ -14,6 +14,7 @@
#pragma once
#include <functional>
#include "ThreadPool.h" // ThreadPool in thrird party
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
...
...
@@ -51,6 +52,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
size_t
computation_count_
{
0
};
size_t
max_async_computation
{
100
};
std
::
vector
<
std
::
function
<
void
()
>>
drop_functions_
;
};
}
// namespace details
...
...
paddle/fluid/framework/reader.cc
浏览文件 @
9dd64d83
...
...
@@ -29,7 +29,7 @@ void FileReader::ReadNext(std::vector<LoDTensor> *out) {
PADDLE_ENFORCE_EQ
(
actual
.
size
(),
expect
.
size
());
for
(
int
j
=
0
;
j
<
actual
.
size
();
++
j
)
{
PADDLE_ENFORCE
(
actual
[
i
]
==
expect
[
i
]
||
expect
[
i
]
==
-
1
);
//
PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1);
}
}
}
...
...
python/paddle/fluid/tests/unittests/.gitignore
浏览文件 @
9dd64d83
...
...
@@ -3,3 +3,4 @@ mnist_0.recordio
mnist_1.recordio
mnist_2.recordio
flowers.recordio
wmt16.recordio
python/paddle/fluid/tests/unittests/test_parallel_executor.py
浏览文件 @
9dd64d83
...
...
@@ -17,6 +17,7 @@ import paddle.fluid as fluid
import
paddle.v2
as
paddle
import
paddle.v2.dataset.mnist
as
mnist
import
paddle.v2.dataset.flowers
as
flowers
import
paddle.v2.dataset.wmt16
as
wmt16
import
numpy
...
...
@@ -245,3 +246,161 @@ class TestResnet(TestParallelExecutorBase):
def
test_resnet
(
self
):
self
.
check_network_convergence
(
SE_ResNeXt152
,
iter
=
200
)
class
ModelHyperParams
(
object
):
# Dictionary size for source and target language. This model directly uses
# paddle.dataset.wmt16 in which <bos>, <eos> and <unk> token has
# alreay been added, but the <pad> token is not added. Transformer requires
# sequences in a mini-batch are padded to have the same length. A <pad> token is
# added into the original dictionary in paddle.dateset.wmt16.
# size of source word dictionary.
src_vocab_size
=
10000
# index for <pad> token in source language.
src_pad_idx
=
src_vocab_size
# size of target word dictionay
trg_vocab_size
=
10000
# index for <pad> token in target language.
trg_pad_idx
=
trg_vocab_size
# position value corresponding to the <pad> token.
pos_pad_idx
=
0
# max length of sequences. It should plus 1 to include position
# padding token for position encoding.
max_length
=
50
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model
=
512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid
=
1024
# the dimension that keys are projected to for dot-product attention.
d_key
=
64
# the dimension that values are projected to for dot-product attention.
d_value
=
64
# number of head used in multi-head attention.
n_head
=
8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer
=
6
# dropout rate used by all dropout layers.
dropout
=
0.1
import
numpy
as
np
def
prepare_batch_input
(
insts
,
src_pad_idx
,
trg_pad_idx
,
n_head
):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias. Then, convert the numpy
data to tensors and return a dict mapping names to tensors.
"""
def
__pad_batch_data
(
insts
,
pad_idx
,
is_target
=
False
,
return_pos
=
True
,
return_attn_bias
=
True
,
return_max_len
=
True
):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
return_list
=
[]
max_len
=
max
(
len
(
inst
)
for
inst
in
insts
)
inst_data
=
np
.
array
(
[
inst
+
[
pad_idx
]
*
(
max_len
-
len
(
inst
))
for
inst
in
insts
])
return_list
+=
[
inst_data
.
astype
(
"int64"
).
reshape
([
-
1
,
1
])]
if
return_pos
:
inst_pos
=
np
.
array
([[
pos_i
+
1
if
w_i
!=
pad_idx
else
0
for
pos_i
,
w_i
in
enumerate
(
inst
)
]
for
inst
in
inst_data
])
return_list
+=
[
inst_pos
.
astype
(
"int64"
).
reshape
([
-
1
,
1
])]
if
return_attn_bias
:
if
is_target
:
# This is used to avoid attention on paddings and subsequent
# words.
slf_attn_bias_data
=
np
.
ones
((
inst_data
.
shape
[
0
],
max_len
,
max_len
))
slf_attn_bias_data
=
np
.
triu
(
slf_attn_bias_data
,
1
).
reshape
(
[
-
1
,
1
,
max_len
,
max_len
])
slf_attn_bias_data
=
np
.
tile
(
slf_attn_bias_data
,
[
1
,
n_head
,
1
,
1
])
*
[
-
1e9
]
else
:
# This is used to avoid attention on paddings.
slf_attn_bias_data
=
np
.
array
([[
0
]
*
len
(
inst
)
+
[
-
1e9
]
*
(
max_len
-
len
(
inst
))
for
inst
in
insts
])
slf_attn_bias_data
=
np
.
tile
(
slf_attn_bias_data
.
reshape
([
-
1
,
1
,
1
,
max_len
]),
[
1
,
n_head
,
max_len
,
1
])
return_list
+=
[
slf_attn_bias_data
.
astype
(
"float32"
)]
if
return_max_len
:
return_list
+=
[
max_len
]
return
return_list
if
len
(
return_list
)
>
1
else
return_list
[
0
]
def
data_to_tensor
(
data_list
,
name_list
,
input_dict
,
place
):
assert
len
(
data_list
)
==
len
(
name_list
)
for
i
in
range
(
len
(
name_list
)):
tensor
=
fluid
.
LoDTensor
()
tensor
.
set
(
data_list
[
i
],
place
)
input_dict
[
name_list
[
i
]]
=
tensor
src_word
,
src_pos
,
src_slf_attn_bias
,
src_max_len
=
__pad_batch_data
(
[
inst
[
0
]
for
inst
in
insts
],
src_pad_idx
,
is_target
=
False
)
trg_word
,
trg_pos
,
trg_slf_attn_bias
,
trg_max_len
=
__pad_batch_data
(
[
inst
[
1
]
for
inst
in
insts
],
trg_pad_idx
,
is_target
=
True
)
trg_src_attn_bias
=
np
.
tile
(
src_slf_attn_bias
[:,
:,
::
src_max_len
,
:],
[
1
,
1
,
trg_max_len
,
1
]).
astype
(
"float32"
)
lbl_word
=
__pad_batch_data
([
inst
[
2
]
for
inst
in
insts
],
trg_pad_idx
,
False
,
False
,
False
,
False
)
lbl_weight
=
(
lbl_word
!=
trg_pad_idx
).
astype
(
"float32"
).
reshape
([
-
1
,
1
])
return
[
src_word
,
src_pos
,
trg_word
,
trg_pos
,
src_slf_attn_bias
,
trg_slf_attn_bias
,
trg_src_attn_bias
,
lbl_word
,
lbl_weight
]
import
transformer_model
def
transformer
():
return
transformer_model
.
transformer
(
ModelHyperParams
.
src_vocab_size
+
1
,
ModelHyperParams
.
trg_vocab_size
+
1
,
ModelHyperParams
.
max_length
+
1
,
ModelHyperParams
.
n_layer
,
ModelHyperParams
.
n_head
,
ModelHyperParams
.
d_key
,
ModelHyperParams
.
d_value
,
ModelHyperParams
.
d_model
,
ModelHyperParams
.
d_inner_hid
,
ModelHyperParams
.
dropout
,
ModelHyperParams
.
src_pad_idx
,
ModelHyperParams
.
trg_pad_idx
,
ModelHyperParams
.
pos_pad_idx
)
class
TestTransformer
(
TestParallelExecutorBase
):
@
classmethod
def
setUpClass
(
cls
):
reader
=
paddle
.
batch
(
wmt16
.
train
(
ModelHyperParams
.
src_vocab_size
,
ModelHyperParams
.
trg_vocab_size
),
batch_size
=
transformer_model
.
batch_size
)
with
fluid
.
recordio_writer
.
create_recordio_writer
(
"./wmt16.recordio"
)
as
writer
:
for
batch
in
reader
():
for
tensor
in
prepare_batch_input
(
batch
,
ModelHyperParams
.
src_pad_idx
,
ModelHyperParams
.
trg_pad_idx
,
ModelHyperParams
.
n_head
):
t
=
fluid
.
LoDTensor
()
t
.
set
(
tensor
,
fluid
.
CPUPlace
())
writer
.
append_tensor
(
t
)
writer
.
complete_append_tensor
()
def
test_main
(
self
):
self
.
check_network_convergence
(
transformer
)
python/paddle/fluid/tests/unittests/transformer_model.py
0 → 100644
浏览文件 @
9dd64d83
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
functools
import
partial
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid.layers
as
layers
pos_enc_param_names
=
(
"src_pos_enc_table"
,
"trg_pos_enc_table"
,
)
batch_size
=
64
def
position_encoding_init
(
n_position
,
d_pos_vec
):
"""
Generate the initial values for the sinusoid position encoding table.
"""
position_enc
=
np
.
array
([[
pos
/
np
.
power
(
10000
,
2
*
(
j
//
2
)
/
d_pos_vec
)
for
j
in
range
(
d_pos_vec
)
]
if
pos
!=
0
else
np
.
zeros
(
d_pos_vec
)
for
pos
in
range
(
n_position
)])
position_enc
[
1
:,
0
::
2
]
=
np
.
sin
(
position_enc
[
1
:,
0
::
2
])
# dim 2i
position_enc
[
1
:,
1
::
2
]
=
np
.
cos
(
position_enc
[
1
:,
1
::
2
])
# dim 2i+1
return
position_enc
.
astype
(
"float32"
)
def
multi_head_attention
(
queries
,
keys
,
values
,
attn_bias
,
d_key
,
d_value
,
d_model
,
n_head
=
1
,
dropout_rate
=
0.
):
"""
Multi-Head Attention. Note that attn_bias is added to the logit before
computing softmax activiation to mask certain selected positions so that
they will not considered in attention weights.
"""
if
not
(
len
(
queries
.
shape
)
==
len
(
keys
.
shape
)
==
len
(
values
.
shape
)
==
3
):
raise
ValueError
(
"Inputs: quries, keys and values should all be 3-D tensors."
)
def
__compute_qkv
(
queries
,
keys
,
values
,
n_head
,
d_key
,
d_value
):
"""
Add linear projection to queries, keys, and values.
"""
q
=
layers
.
fc
(
input
=
queries
,
size
=
d_key
*
n_head
,
param_attr
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
,
fan_in
=
d_model
*
d_key
,
fan_out
=
n_head
*
d_key
),
bias_attr
=
False
,
num_flatten_dims
=
2
)
k
=
layers
.
fc
(
input
=
keys
,
size
=
d_key
*
n_head
,
param_attr
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
,
fan_in
=
d_model
*
d_key
,
fan_out
=
n_head
*
d_key
),
bias_attr
=
False
,
num_flatten_dims
=
2
)
v
=
layers
.
fc
(
input
=
values
,
size
=
d_value
*
n_head
,
param_attr
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
,
fan_in
=
d_model
*
d_value
,
fan_out
=
n_head
*
d_value
),
bias_attr
=
False
,
num_flatten_dims
=
2
)
return
q
,
k
,
v
def
__split_heads
(
x
,
n_head
):
"""
Reshape the last dimension of inpunt tensor x so that it becomes two
dimensions and then transpose. Specifically, input a tensor with shape
[bs, max_sequence_length, n_head * hidden_dim] then output a tensor
with shape [bs, n_head, max_sequence_length, hidden_dim].
"""
if
n_head
==
1
:
return
x
hidden_size
=
x
.
shape
[
-
1
]
# FIXME(guosheng): Decouple the program desc with batch_size.
reshaped
=
layers
.
reshape
(
x
=
x
,
shape
=
[
batch_size
,
-
1
,
n_head
,
hidden_size
//
n_head
])
# permuate the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
return
layers
.
transpose
(
x
=
reshaped
,
perm
=
[
0
,
2
,
1
,
3
])
def
__combine_heads
(
x
):
"""
Transpose and then reshape the last two dimensions of inpunt tensor x
so that it becomes one dimension, which is reverse to __split_heads.
"""
if
len
(
x
.
shape
)
==
3
:
return
x
if
len
(
x
.
shape
)
!=
4
:
raise
ValueError
(
"Input(x) should be a 4-D Tensor."
)
trans_x
=
layers
.
transpose
(
x
,
perm
=
[
0
,
2
,
1
,
3
])
# FIXME(guosheng): Decouple the program desc with batch_size.
return
layers
.
reshape
(
x
=
trans_x
,
shape
=
map
(
int
,
[
batch_size
,
-
1
,
trans_x
.
shape
[
2
]
*
trans_x
.
shape
[
3
]]))
def
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_bias
,
d_model
,
dropout_rate
):
"""
Scaled Dot-Product Attention
"""
# FIXME(guosheng): Optimize the shape in reshape_op or softmax_op.
# The current implementation of softmax_op only supports 2D tensor,
# consequently it cannot be directly used here.
# If to use the reshape_op, Besides, the shape of product inferred in
# compile-time is not the actual shape in run-time. It cann't be used
# to set the attribute of reshape_op.
# So, here define the softmax for temporary solution.
def
__softmax
(
x
,
eps
=
1e-9
):
exp_out
=
layers
.
exp
(
x
=
x
)
sum_out
=
layers
.
reduce_sum
(
exp_out
,
dim
=-
1
,
keep_dim
=
False
)
return
layers
.
elementwise_div
(
x
=
exp_out
,
y
=
sum_out
,
axis
=
0
)
scaled_q
=
layers
.
scale
(
x
=
q
,
scale
=
d_model
**-
0.5
)
product
=
layers
.
matmul
(
x
=
scaled_q
,
y
=
k
,
transpose_y
=
True
)
weights
=
__softmax
(
layers
.
elementwise_add
(
x
=
product
,
y
=
attn_bias
))
if
dropout_rate
:
weights
=
layers
.
dropout
(
weights
,
dropout_prob
=
dropout_rate
,
is_test
=
False
)
out
=
layers
.
matmul
(
weights
,
v
)
return
out
q
,
k
,
v
=
__compute_qkv
(
queries
,
keys
,
values
,
n_head
,
d_key
,
d_value
)
q
=
__split_heads
(
q
,
n_head
)
k
=
__split_heads
(
k
,
n_head
)
v
=
__split_heads
(
v
,
n_head
)
ctx_multiheads
=
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_bias
,
d_model
,
dropout_rate
)
out
=
__combine_heads
(
ctx_multiheads
)
# Project back to the model size.
proj_out
=
layers
.
fc
(
input
=
out
,
size
=
d_model
,
param_attr
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
),
bias_attr
=
False
,
num_flatten_dims
=
2
)
return
proj_out
def
positionwise_feed_forward
(
x
,
d_inner_hid
,
d_hid
):
"""
Position-wise Feed-Forward Networks.
This module consists of two linear transformations with a ReLU activation
in between, which is applied to each position separately and identically.
"""
hidden
=
layers
.
fc
(
input
=
x
,
size
=
d_inner_hid
,
num_flatten_dims
=
2
,
param_attr
=
fluid
.
initializer
.
Uniform
(
low
=-
(
d_hid
**-
0.5
),
high
=
(
d_hid
**-
0.5
)),
act
=
"relu"
)
out
=
layers
.
fc
(
input
=
hidden
,
size
=
d_hid
,
num_flatten_dims
=
2
,
param_attr
=
fluid
.
initializer
.
Uniform
(
low
=-
(
d_inner_hid
**-
0.5
),
high
=
(
d_inner_hid
**-
0.5
)))
return
out
def
pre_post_process_layer
(
prev_out
,
out
,
process_cmd
,
dropout
=
0.
):
"""
Add residual connection, layer normalization and droput to the out tensor
optionally according to the value of process_cmd.
This will be used before or after multi-head attention and position-wise
feed-forward networks.
"""
for
cmd
in
process_cmd
:
if
cmd
==
"a"
:
# add residual connection
out
=
out
+
prev_out
if
prev_out
else
out
elif
cmd
==
"n"
:
# add layer normalization
out
=
layers
.
layer_norm
(
out
,
begin_norm_axis
=
len
(
out
.
shape
)
-
1
,
param_attr
=
fluid
.
initializer
.
Constant
(
1.
),
bias_attr
=
fluid
.
initializer
.
Constant
(
0.
))
elif
cmd
==
"d"
:
# add dropout
if
dropout
:
out
=
layers
.
dropout
(
out
,
dropout_prob
=
dropout
,
is_test
=
False
)
return
out
pre_process_layer
=
partial
(
pre_post_process_layer
,
None
)
post_process_layer
=
pre_post_process_layer
def
prepare_encoder
(
src_word
,
src_pos
,
src_vocab_size
,
src_emb_dim
,
src_pad_idx
,
src_max_len
,
dropout
=
0.
,
pos_pad_idx
=
0
,
pos_enc_param_name
=
None
):
"""Add word embeddings and position encodings.
The output tensor has a shape of:
[batch_size, max_src_length_in_batch, d_model].
This module is used at the bottom of the encoder stacks.
"""
src_word_emb
=
layers
.
embedding
(
src_word
,
size
=
[
src_vocab_size
,
src_emb_dim
],
padding_idx
=
src_pad_idx
,
param_attr
=
fluid
.
initializer
.
Normal
(
0.
,
1.
))
src_pos_enc
=
layers
.
embedding
(
src_pos
,
size
=
[
src_max_len
,
src_emb_dim
],
padding_idx
=
pos_pad_idx
,
param_attr
=
fluid
.
ParamAttr
(
name
=
pos_enc_param_name
,
trainable
=
False
))
enc_input
=
src_word_emb
+
src_pos_enc
# FIXME(guosheng): Decouple the program desc with batch_size.
enc_input
=
layers
.
reshape
(
x
=
enc_input
,
shape
=
[
batch_size
,
-
1
,
src_emb_dim
])
return
layers
.
dropout
(
enc_input
,
dropout_prob
=
dropout
,
is_test
=
False
)
if
dropout
else
enc_input
prepare_encoder
=
partial
(
prepare_encoder
,
pos_enc_param_name
=
pos_enc_param_names
[
0
])
prepare_decoder
=
partial
(
prepare_encoder
,
pos_enc_param_name
=
pos_enc_param_names
[
1
])
def
encoder_layer
(
enc_input
,
attn_bias
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
dropout_rate
=
0.
):
"""The encoder layers that can be stacked to form a deep encoder.
This module consits of a multi-head (self) attention followed by
position-wise feed-forward networks and both the two components companied
with the post_process_layer to add residual connection, layer normalization
and droput.
"""
attn_output
=
multi_head_attention
(
enc_input
,
enc_input
,
enc_input
,
attn_bias
,
d_key
,
d_value
,
d_model
,
n_head
,
dropout_rate
)
attn_output
=
post_process_layer
(
enc_input
,
attn_output
,
"dan"
,
dropout_rate
)
ffd_output
=
positionwise_feed_forward
(
attn_output
,
d_inner_hid
,
d_model
)
return
post_process_layer
(
attn_output
,
ffd_output
,
"dan"
,
dropout_rate
)
def
encoder
(
enc_input
,
attn_bias
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
dropout_rate
=
0.
):
"""
The encoder is composed of a stack of identical layers returned by calling
encoder_layer.
"""
for
i
in
range
(
n_layer
):
enc_output
=
encoder_layer
(
enc_input
,
attn_bias
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
dropout_rate
)
enc_input
=
enc_output
return
enc_output
def
decoder_layer
(
dec_input
,
enc_output
,
slf_attn_bias
,
dec_enc_attn_bias
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
dropout_rate
=
0.
):
""" The layer to be stacked in decoder part.
The structure of this module is similar to that in the encoder part except
a multi-head attention is added to implement encoder-decoder attention.
"""
slf_attn_output
=
multi_head_attention
(
dec_input
,
dec_input
,
dec_input
,
slf_attn_bias
,
d_key
,
d_value
,
d_model
,
n_head
,
dropout_rate
,
)
slf_attn_output
=
post_process_layer
(
dec_input
,
slf_attn_output
,
"dan"
,
# residual connection + dropout + layer normalization
dropout_rate
,
)
enc_attn_output
=
multi_head_attention
(
slf_attn_output
,
enc_output
,
enc_output
,
dec_enc_attn_bias
,
d_key
,
d_value
,
d_model
,
n_head
,
dropout_rate
,
)
enc_attn_output
=
post_process_layer
(
slf_attn_output
,
enc_attn_output
,
"dan"
,
# residual connection + dropout + layer normalization
dropout_rate
,
)
ffd_output
=
positionwise_feed_forward
(
enc_attn_output
,
d_inner_hid
,
d_model
,
)
dec_output
=
post_process_layer
(
enc_attn_output
,
ffd_output
,
"dan"
,
# residual connection + dropout + layer normalization
dropout_rate
,
)
return
dec_output
def
decoder
(
dec_input
,
enc_output
,
dec_slf_attn_bias
,
dec_enc_attn_bias
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
dropout_rate
=
0.
):
"""
The decoder is composed of a stack of identical decoder_layer layers.
"""
for
i
in
range
(
n_layer
):
dec_output
=
decoder_layer
(
dec_input
,
enc_output
,
dec_slf_attn_bias
,
dec_enc_attn_bias
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
dropout_rate
,
)
dec_input
=
dec_output
return
dec_output
def
transformer
(
src_vocab_size
,
trg_vocab_size
,
max_length
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
dropout_rate
,
src_pad_idx
,
trg_pad_idx
,
pos_pad_idx
,
):
file_obj
=
fluid
.
layers
.
open_recordio_file
(
filename
=
'./wmt16.recordio'
,
shapes
=
[
[
batch_size
*
max_length
,
1
],
[
batch_size
*
max_length
,
1
],
[
batch_size
*
max_length
,
1
],
[
batch_size
*
max_length
,
1
],
[
batch_size
,
n_head
,
max_length
,
max_length
],
[
batch_size
,
n_head
,
max_length
,
max_length
],
[
batch_size
,
n_head
,
max_length
,
max_length
],
[
batch_size
*
max_length
,
1
],
[
batch_size
*
max_length
,
1
],
],
dtypes
=
[
'int64'
,
'int64'
,
'int64'
,
'int64'
,
'float32'
,
'float32'
,
'float32'
,
'int64'
,
'float32'
,
],
lod_levels
=
[
0
]
*
9
)
src_word
,
src_pos
,
trg_word
,
trg_pos
,
src_slf_attn_bias
,
trg_slf_attn_bias
,
trg_src_attn_bias
,
gold
,
weights
=
fluid
.
layers
.
read_file
(
file_obj
)
enc_input
=
prepare_encoder
(
src_word
,
src_pos
,
src_vocab_size
,
d_model
,
src_pad_idx
,
max_length
,
dropout_rate
,
)
enc_output
=
encoder
(
enc_input
,
src_slf_attn_bias
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
dropout_rate
,
)
dec_input
=
prepare_decoder
(
trg_word
,
trg_pos
,
trg_vocab_size
,
d_model
,
trg_pad_idx
,
max_length
,
dropout_rate
,
)
dec_output
=
decoder
(
dec_input
,
enc_output
,
trg_slf_attn_bias
,
trg_src_attn_bias
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
dropout_rate
,
)
# TODO(guosheng): Share the weight matrix between the embedding layers and
# the pre-softmax linear transformation.
predict
=
layers
.
reshape
(
x
=
layers
.
fc
(
input
=
dec_output
,
size
=
trg_vocab_size
,
param_attr
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
),
bias_attr
=
False
,
num_flatten_dims
=
2
),
shape
=
[
-
1
,
trg_vocab_size
],
act
=
"softmax"
)
cost
=
layers
.
cross_entropy
(
input
=
predict
,
label
=
gold
)
weighted_cost
=
cost
*
weights
return
layers
.
reduce_sum
(
weighted_cost
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录