Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
a6ec3a0d
M
models
项目概览
PaddlePaddle
/
models
大约 1 年 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a6ec3a0d
编写于
6月 08, 2018
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Tune the Transformer model for wmt14
上级
f93838a4
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
176 addition
and
120 deletion
+176
-120
fluid/neural_machine_translation/transformer/config.py
fluid/neural_machine_translation/transformer/config.py
+51
-32
fluid/neural_machine_translation/transformer/model.py
fluid/neural_machine_translation/transformer/model.py
+123
-88
fluid/neural_machine_translation/transformer/train.py
fluid/neural_machine_translation/transformer/train.py
+2
-0
未找到文件。
fluid/neural_machine_translation/transformer/config.py
浏览文件 @
a6ec3a0d
class
TrainTaskConfig
(
object
):
use_gpu
=
True
# the epoch number to train.
pass_num
=
3
0
pass_num
=
20
0
# the number of sequences contained in a mini-batch.
batch_size
=
32
# the hyper parameters for Adam optimizer.
# This static learning_rate will be multiplied to the LearningRateScheduler
# derived learning rate the to get the final learning rate.
learning_rate
=
1
learning_rate
=
2
beta1
=
0.9
beta2
=
0.9
8
beta2
=
0.9
97
eps
=
1e-9
# the parameters for learning rate scheduling.
warmup_steps
=
4
000
warmup_steps
=
8
000
# the flag indicating to use average loss or sum loss when training.
use_avg_cost
=
True
# the weight used to mix up the ground-truth distribution and the fixed
...
...
@@ -33,12 +33,12 @@ class TrainTaskConfig(object):
class
InferTaskConfig
(
object
):
use_gpu
=
Tru
e
use_gpu
=
Fals
e
# the number of examples in one run for sequence generation.
batch_size
=
10
batch_size
=
2
# the parameters for beam search.
beam_size
=
5
max_
length
=
30
max_
out_len
=
30
# the number of decoded sentences to output.
n_best
=
1
# the flags indicating whether to output the special tokens.
...
...
@@ -55,26 +55,26 @@ class ModelHyperParams(object):
# included in dict can be used to pad, since the paddings' loss will be
# masked out and make no effect on parameter gradients.
# size of source word dictionary.
src_vocab_size
=
1
0000
src_vocab_size
=
5
0000
# size of target word dictionay
trg_vocab_size
=
1
0000
trg_vocab_size
=
5
0000
# index for <bos> token
bos_idx
=
0
bos_idx
=
1
# index for <eos> token
eos_idx
=
1
eos_idx
=
2
# index for <unk> token
unk_idx
=
2
unk_idx
=
0
# max length of sequences.
# The size of position encoding table should at least plus 1, since the
# sinusoid position encoding starts from 1 and 0 can be used as the padding
# token for position encoding.
max_length
=
50
max_length
=
256
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model
=
512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid
=
1024
d_inner_hid
=
2048
# the dimension that keys are projected to for dot-product attention.
d_key
=
64
# the dimension that values are projected to for dot-product attention.
...
...
@@ -89,7 +89,7 @@ class ModelHyperParams(object):
def
merge_cfg_from_list
(
cfg_list
,
g_cfgs
):
"""
Set the above global configurations using the cfg_list.
Set the above global configurations using the cfg_list.
"""
assert
len
(
cfg_list
)
%
2
==
0
for
key
,
value
in
zip
(
cfg_list
[
0
::
2
],
cfg_list
[
1
::
2
]):
...
...
@@ -103,23 +103,28 @@ def merge_cfg_from_list(cfg_list, g_cfgs):
break
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size
=
-
1
# The placeholder for squence length in compile time.
seq_len
=
ModelHyperParams
.
max_length
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs
=
{
# The actual data shape of src_word is:
# [batch_size * max_src_len_in_batch, 1]
"src_word"
:
[(
1
*
(
ModelHyperParams
.
max_length
+
1
),
1L
),
"int64"
],
"src_word"
:
[(
batch_size
*
seq_len
,
1L
),
"int64"
,
2
],
# The actual data shape of src_pos is:
# [batch_size * max_src_len_in_batch, 1]
"src_pos"
:
[(
1
*
(
ModelHyperParams
.
max_length
+
1
)
,
1L
),
"int64"
],
"src_pos"
:
[(
batch_size
*
seq_len
,
1L
),
"int64"
],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias"
:
[(
1
,
ModelHyperParams
.
n_head
,
(
ModelHyperParams
.
max_length
+
1
),
(
ModelHyperParams
.
max_length
+
1
)),
"float32"
],
"src_slf_attn_bias"
:
[(
batch_size
,
ModelHyperParams
.
n_head
,
seq_len
,
seq_len
),
"float32"
],
# This shape input is used to reshape the output of embedding layer.
"src_data_shape"
:
[(
3L
,
),
"int32"
],
# This shape input is used to reshape before softmax in self attention.
...
...
@@ -128,24 +133,23 @@ input_descs = {
"src_slf_attn_post_softmax_shape"
:
[(
4L
,
),
"int32"
],
# The actual data shape of trg_word is:
# [batch_size * max_trg_len_in_batch, 1]
"trg_word"
:
[(
1
*
(
ModelHyperParams
.
max_length
+
1
),
1L
),
"int64"
],
"trg_word"
:
[(
batch_size
*
seq_len
,
1L
),
"int64"
,
2
],
# lod_level is only used in fast decoder.
# The actual data shape of trg_pos is:
# [batch_size * max_trg_len_in_batch, 1]
"trg_pos"
:
[(
1
*
(
ModelHyperParams
.
max_length
+
1
)
,
1L
),
"int64"
],
"trg_pos"
:
[(
batch_size
*
seq_len
,
1L
),
"int64"
],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
"trg_slf_attn_bias"
:
[(
1
,
ModelHyperParams
.
n_head
,
(
ModelHyperParams
.
max_length
+
1
),
(
ModelHyperParams
.
max_length
+
1
)),
"float32"
],
"trg_slf_attn_bias"
:
[(
batch_size
,
ModelHyperParams
.
n_head
,
seq_len
,
seq_len
),
"float32"
],
# This input is used to remove attention weights on paddings of the source
# input in the encoder-decoder attention.
# The actual data shape of trg_src_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias"
:
[(
1
,
ModelHyperParams
.
n_head
,
(
ModelHyperParams
.
max_length
+
1
),
(
ModelHyperParams
.
max_length
+
1
)),
"float32"
],
"trg_src_attn_bias"
:
[(
batch_size
,
ModelHyperParams
.
n_head
,
seq_len
,
seq_len
),
"float32"
],
# This shape input is used to reshape the output of embedding layer.
"trg_data_shape"
:
[(
3L
,
),
"int32"
],
# This shape input is used to reshape before softmax in self attention.
...
...
@@ -161,17 +165,23 @@ input_descs = {
# This input is used in independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
"enc_output"
:
[(
1
,
(
ModelHyperParams
.
max_length
+
1
),
ModelHyperParams
.
d_model
),
"float32"
],
"enc_output"
:
[(
batch_size
,
seq_len
,
ModelHyperParams
.
d_model
),
"float32"
],
# The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_word"
:
[(
1
*
(
ModelHyperParams
.
max_length
+
1
)
,
1L
),
"int64"
],
"lbl_word"
:
[(
batch_size
*
seq_len
,
1L
),
"int64"
],
# This input is used to mask out the loss of paddding tokens.
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight"
:
[(
1
*
(
ModelHyperParams
.
max_length
+
1
),
1L
),
"float32"
],
"lbl_weight"
:
[(
batch_size
*
seq_len
,
1L
),
"float32"
],
# These inputs are used to change the shape tensor in beam-search decoder.
"trg_slf_attn_pre_softmax_shape_delta"
:
[(
2L
,
),
"int32"
],
"trg_slf_attn_post_softmax_shape_delta"
:
[(
4L
,
),
"int32"
],
"init_score"
:
[(
batch_size
,
1L
),
"float32"
],
}
word_emb_param_names
=
(
"src_word_emb_table"
,
"trg_word_emb_table"
,
)
# Names of position encoding table which will be initialized externally.
pos_enc_param_names
=
(
"src_pos_enc_table"
,
...
...
@@ -200,3 +210,12 @@ decoder_util_input_fields = (
label_data_input_fields
=
(
"lbl_word"
,
"lbl_weight"
,
)
# In fast decoder, trg_pos (only containing the current time step) is generated
# by ops and trg_slf_attn_bias is not needed.
fast_decoder_data_input_fields
=
(
"trg_word"
,
"init_score"
,
"trg_src_attn_bias"
,
)
fast_decoder_util_input_fields
=
decoder_util_input_fields
+
(
"trg_slf_attn_pre_softmax_shape_delta"
,
"trg_slf_attn_post_softmax_shape_delta"
,
)
fluid/neural_machine_translation/transformer/model.py
浏览文件 @
a6ec3a0d
...
...
@@ -6,6 +6,8 @@ import paddle.fluid.layers as layers
from
config
import
*
WEIGHT_SHARING
=
True
def
position_encoding_init
(
n_position
,
d_pos_vec
):
"""
...
...
@@ -30,7 +32,8 @@ def multi_head_attention(queries,
n_head
=
1
,
dropout_rate
=
0.
,
pre_softmax_shape
=
None
,
post_softmax_shape
=
None
):
post_softmax_shape
=
None
,
cache
=
None
):
"""
Multi-Head Attention. Note that attn_bias is added to the logit before
computing softmax activiation to mask certain selected positions so that
...
...
@@ -44,30 +47,30 @@ def multi_head_attention(queries,
"""
Add linear projection to queries, keys, and values.
"""
q
=
layers
.
fc
(
input
=
queries
,
size
=
d_key
*
n_head
,
param_attr
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
,
fan_in
=
d_model
*
d_key
,
fan_out
=
n_head
*
d_key
),
bias_attr
=
False
,
num_flatten_dims
=
2
)
k
=
layers
.
fc
(
input
=
keys
,
size
=
d_key
*
n_head
,
param_attr
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
,
fan_in
=
d_model
*
d_key
,
fan_out
=
n_head
*
d_key
),
bias_attr
=
False
,
num_flatten_dims
=
2
)
v
=
layers
.
fc
(
input
=
values
,
size
=
d_value
*
n_head
,
param_attr
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
,
fan_in
=
d_model
*
d_value
,
fan_out
=
n_head
*
d_value
),
bias_attr
=
False
,
num_flatten_dims
=
2
)
q
=
layers
.
fc
(
input
=
queries
,
size
=
d_key
*
n_head
,
param_attr
=
fluid
.
initializer
.
Xavier
(
uniform
=
True
)
,
#
fan_in=d_model * d_key,
#
fan_out=n_head * d_key),
bias_attr
=
False
,
num_flatten_dims
=
2
)
k
=
layers
.
fc
(
input
=
keys
,
size
=
d_key
*
n_head
,
param_attr
=
fluid
.
initializer
.
Xavier
(
uniform
=
True
)
,
#
fan_in=d_model * d_key,
#
fan_out=n_head * d_key),
bias_attr
=
False
,
num_flatten_dims
=
2
)
v
=
layers
.
fc
(
input
=
values
,
size
=
d_value
*
n_head
,
param_attr
=
fluid
.
initializer
.
Xavier
(
uniform
=
True
)
,
#
fan_in=d_model * d_value,
#
fan_out=n_head * d_value),
bias_attr
=
False
,
num_flatten_dims
=
2
)
return
q
,
k
,
v
def
__split_heads
(
x
,
n_head
):
...
...
@@ -84,7 +87,7 @@ def multi_head_attention(queries,
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped
=
layers
.
reshape
(
x
=
x
,
shape
=
[
0
,
-
1
,
n_head
,
hidden_size
//
n_head
])
x
=
x
,
shape
=
[
0
,
0
,
n_head
,
hidden_size
//
n_head
])
# permuate the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
...
...
@@ -104,13 +107,13 @@ def multi_head_attention(queries,
# size of the input as the output dimension size.
return
layers
.
reshape
(
x
=
trans_x
,
shape
=
map
(
int
,
[
0
,
-
1
,
trans_x
.
shape
[
2
]
*
trans_x
.
shape
[
3
]]))
shape
=
map
(
int
,
[
0
,
0
,
trans_x
.
shape
[
2
]
*
trans_x
.
shape
[
3
]]))
def
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_bias
,
d_model
,
dropout_rate
):
"""
Scaled Dot-Product Attention
"""
scaled_q
=
layers
.
scale
(
x
=
q
,
scale
=
d_
model
**-
0.5
)
scaled_q
=
layers
.
scale
(
x
=
q
,
scale
=
d_
key
**-
0.5
)
product
=
layers
.
matmul
(
x
=
scaled_q
,
y
=
k
,
transpose_y
=
True
)
weights
=
layers
.
reshape
(
x
=
layers
.
elementwise_add
(
...
...
@@ -123,11 +126,15 @@ def multi_head_attention(queries,
if
dropout_rate
:
weights
=
layers
.
dropout
(
weights
,
dropout_prob
=
dropout_rate
,
is_test
=
False
)
out
=
layers
.
matmul
(
weights
,
v
)
return
out
q
,
k
,
v
=
__compute_qkv
(
queries
,
keys
,
values
,
n_head
,
d_key
,
d_value
)
if
cache
is
not
None
:
# use cache and concat time steps
k
=
cache
[
"k"
]
=
layers
.
concat
([
cache
[
"k"
],
k
],
axis
=
1
)
v
=
cache
[
"v"
]
=
layers
.
concat
([
cache
[
"v"
],
v
],
axis
=
1
)
q
=
__split_heads
(
q
,
n_head
)
k
=
__split_heads
(
k
,
n_head
)
v
=
__split_heads
(
v
,
n_head
)
...
...
@@ -136,7 +143,6 @@ def multi_head_attention(queries,
dropout_rate
)
out
=
__combine_heads
(
ctx_multiheads
)
# Project back to the model size.
proj_out
=
layers
.
fc
(
input
=
out
,
size
=
d_model
,
...
...
@@ -146,23 +152,32 @@ def multi_head_attention(queries,
return
proj_out
def
positionwise_feed_forward
(
x
,
d_inner_hid
,
d_hid
):
def
positionwise_feed_forward
(
x
,
d_inner_hid
,
d_hid
,
dropout_rate
=
0.
):
"""
Position-wise Feed-Forward Networks.
This module consists of two linear transformations with a ReLU activation
in between, which is applied to each position separately and identically.
"""
hidden
=
layers
.
fc
(
input
=
x
,
size
=
d_inner_hid
,
num_flatten_dims
=
2
,
param_attr
=
fluid
.
initializer
.
Uniform
(
low
=-
(
d_hid
**-
0.5
),
high
=
(
d_hid
**-
0.5
)),
act
=
"relu"
)
out
=
layers
.
fc
(
input
=
hidden
,
size
=
d_hid
,
num_flatten_dims
=
2
,
param_attr
=
fluid
.
initializer
.
Uniform
(
low
=-
(
d_inner_hid
**-
0.5
),
high
=
(
d_inner_hid
**-
0.5
)))
hidden
=
layers
.
fc
(
input
=
x
,
size
=
d_inner_hid
,
num_flatten_dims
=
2
,
param_attr
=
fluid
.
initializer
.
Xavier
(
uniform
=
True
),
#param_attr=fluid.initializer.Uniform(
# low=-(d_hid**-0.5), high=(d_hid**-0.5)),
bias_attr
=
True
,
act
=
"relu"
)
if
dropout_rate
:
hidden
=
layers
.
dropout
(
hidden
,
dropout_prob
=
dropout_rate
,
is_test
=
False
)
out
=
layers
.
fc
(
input
=
hidden
,
size
=
d_hid
,
num_flatten_dims
=
2
,
param_attr
=
fluid
.
initializer
.
Xavier
(
uniform
=
True
),
#param_attr=fluid.initializer.Uniform(
# low=-(d_inner_hid**-0.5), high=(d_inner_hid**-0.5)),
bias_attr
=
True
)
return
out
...
...
@@ -200,6 +215,7 @@ def prepare_encoder(src_word,
src_max_len
,
dropout_rate
=
0.
,
src_data_shape
=
None
,
word_emb_param_name
=
None
,
pos_enc_param_name
=
None
):
"""Add word embeddings and position encodings.
The output tensor has a shape of:
...
...
@@ -209,7 +225,10 @@ def prepare_encoder(src_word,
src_word_emb
=
layers
.
embedding
(
src_word
,
size
=
[
src_vocab_size
,
src_emb_dim
],
param_attr
=
fluid
.
initializer
.
Normal
(
0.
,
1.
))
param_attr
=
fluid
.
ParamAttr
(
name
=
word_emb_param_name
,
initializer
=
fluid
.
initializer
.
Normal
(
0.
,
src_emb_dim
**-
0.5
)))
src_word_emb
=
layers
.
scale
(
x
=
src_word_emb
,
scale
=
src_emb_dim
**
0.5
)
src_pos_enc
=
layers
.
embedding
(
src_pos
,
size
=
[
src_max_len
,
src_emb_dim
],
...
...
@@ -218,7 +237,7 @@ def prepare_encoder(src_word,
enc_input
=
src_word_emb
+
src_pos_enc
enc_input
=
layers
.
reshape
(
x
=
enc_input
,
shape
=
[
-
1
,
src_max
_len
,
src_emb_dim
],
shape
=
[
batch_size
,
seq
_len
,
src_emb_dim
],
actual_shape
=
src_data_shape
)
return
layers
.
dropout
(
enc_input
,
dropout_prob
=
dropout_rate
,
...
...
@@ -226,9 +245,14 @@ def prepare_encoder(src_word,
prepare_encoder
=
partial
(
prepare_encoder
,
pos_enc_param_name
=
pos_enc_param_names
[
0
])
prepare_encoder
,
word_emb_param_name
=
word_emb_param_names
[
0
],
pos_enc_param_name
=
pos_enc_param_names
[
0
])
prepare_decoder
=
partial
(
prepare_encoder
,
pos_enc_param_name
=
pos_enc_param_names
[
1
])
prepare_encoder
,
word_emb_param_name
=
word_emb_param_names
[
0
]
if
WEIGHT_SHARING
else
word_emb_param_names
[
1
],
pos_enc_param_name
=
pos_enc_param_names
[
1
])
def
encoder_layer
(
enc_input
,
...
...
@@ -247,13 +271,14 @@ def encoder_layer(enc_input,
with the post_process_layer to add residual connection, layer normalization
and droput.
"""
attn_output
=
multi_head_attention
(
enc_input
,
enc_input
,
enc_input
,
attn_bias
,
d_key
,
d_value
,
d_model
,
n_head
,
dropout_rate
,
pre_softmax_shape
,
post_softmax_shape
)
attn_output
=
post_process_layer
(
enc_input
,
attn_output
,
"dan"
,
dropout_rate
)
ffd_output
=
positionwise_feed_forward
(
attn_output
,
d_inner_hid
,
d_model
)
return
post_process_layer
(
attn_output
,
ffd_output
,
"dan"
,
dropout_rate
)
q
=
k
=
v
=
pre_process_layer
(
enc_input
,
"n"
)
attn_output
=
multi_head_attention
(
q
,
k
,
v
,
attn_bias
,
d_key
,
d_value
,
d_model
,
n_head
,
dropout_rate
,
pre_softmax_shape
,
post_softmax_shape
)
attn_output
=
post_process_layer
(
enc_input
,
attn_output
,
"da"
,
dropout_rate
)
ffd_output
=
positionwise_feed_forward
(
pre_process_layer
(
attn_output
,
"n"
),
d_inner_hid
,
d_model
,
dropout_rate
)
return
post_process_layer
(
attn_output
,
ffd_output
,
"da"
,
dropout_rate
)
def
encoder
(
enc_input
,
...
...
@@ -284,6 +309,7 @@ def encoder(enc_input,
pre_softmax_shape
,
post_softmax_shape
,
)
enc_input
=
enc_output
enc_output
=
pre_process_layer
(
enc_output
,
"n"
)
return
enc_output
...
...
@@ -300,15 +326,17 @@ def decoder_layer(dec_input,
slf_attn_pre_softmax_shape
=
None
,
slf_attn_post_softmax_shape
=
None
,
src_attn_pre_softmax_shape
=
None
,
src_attn_post_softmax_shape
=
None
):
src_attn_post_softmax_shape
=
None
,
cache
=
None
):
""" The layer to be stacked in decoder part.
The structure of this module is similar to that in the encoder part except
a multi-head attention is added to implement encoder-decoder attention.
"""
q
=
k
=
v
=
pre_process_layer
(
dec_input
,
"n"
)
slf_attn_output
=
multi_head_attention
(
dec_input
,
dec_input
,
dec_input
,
q
,
k
,
v
,
slf_attn_bias
,
d_key
,
d_value
,
...
...
@@ -316,14 +344,15 @@ def decoder_layer(dec_input,
n_head
,
dropout_rate
,
slf_attn_pre_softmax_shape
,
slf_attn_post_softmax_shape
,
)
slf_attn_post_softmax_shape
,
cache
,
)
slf_attn_output
=
post_process_layer
(
dec_input
,
slf_attn_output
,
"da
n
"
,
# residual connection + dropout + layer normalization
"da"
,
# residual connection + dropout + layer normalization
dropout_rate
,
)
enc_attn_output
=
multi_head_attention
(
slf_attn_output
,
pre_process_layer
(
slf_attn_output
,
"n"
)
,
enc_output
,
enc_output
,
dec_enc_attn_bias
,
...
...
@@ -337,16 +366,17 @@ def decoder_layer(dec_input,
enc_attn_output
=
post_process_layer
(
slf_attn_output
,
enc_attn_output
,
"da
n
"
,
# residual connection + dropout + layer normalization
"da"
,
# residual connection + dropout + layer normalization
dropout_rate
,
)
ffd_output
=
positionwise_feed_forward
(
enc_attn_output
,
pre_process_layer
(
enc_attn_output
,
"n"
)
,
d_inner_hid
,
d_model
,
)
d_model
,
dropout_rate
,
)
dec_output
=
post_process_layer
(
enc_attn_output
,
ffd_output
,
"da
n
"
,
# residual connection + dropout + layer normalization
"da"
,
# residual connection + dropout + layer normalization
dropout_rate
,
)
return
dec_output
...
...
@@ -365,27 +395,20 @@ def decoder(dec_input,
slf_attn_pre_softmax_shape
=
None
,
slf_attn_post_softmax_shape
=
None
,
src_attn_pre_softmax_shape
=
None
,
src_attn_post_softmax_shape
=
None
):
src_attn_post_softmax_shape
=
None
,
caches
=
None
):
"""
The decoder is composed of a stack of identical decoder_layer layers.
"""
for
i
in
range
(
n_layer
):
dec_output
=
decoder_layer
(
dec_input
,
enc_output
,
dec_slf_attn_bias
,
dec_enc_attn_bias
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
dropout_rate
,
slf_attn_pre_softmax_shape
,
slf_attn_post_softmax_shape
,
src_attn_pre_softmax_shape
,
src_attn_post_softmax_shape
,
)
dec_input
,
enc_output
,
dec_slf_attn_bias
,
dec_enc_attn_bias
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
dropout_rate
,
slf_attn_pre_softmax_shape
,
slf_attn_post_softmax_shape
,
src_attn_pre_softmax_shape
,
src_attn_post_softmax_shape
,
None
if
caches
is
None
else
caches
[
i
])
dec_input
=
dec_output
dec_output
=
pre_process_layer
(
dec_output
,
"n"
)
return
dec_output
...
...
@@ -399,6 +422,8 @@ def make_all_inputs(input_fields):
name
=
input_field
,
shape
=
input_descs
[
input_field
][
0
],
dtype
=
input_descs
[
input_field
][
1
],
lod_level
=
input_descs
[
input_field
][
2
]
if
len
(
input_descs
[
input_field
])
==
3
else
0
,
append_batch_size
=
False
)
inputs
.
append
(
input_var
)
return
inputs
...
...
@@ -459,7 +484,6 @@ def transformer(
logits
=
predict
,
label
=
label
,
soft_label
=
True
if
label_smooth_eps
else
False
)
# cost = layers.softmax_with_cross_entropy(logits=predict, label=gold)
weighted_cost
=
cost
*
weights
sum_cost
=
layers
.
reduce_sum
(
weighted_cost
)
token_num
=
layers
.
reduce_sum
(
weights
)
...
...
@@ -523,7 +547,8 @@ def wrap_decoder(trg_vocab_size,
d_inner_hid
,
dropout_rate
,
dec_inputs
=
None
,
enc_output
=
None
):
enc_output
=
None
,
caches
=
None
):
"""
The wrapper assembles together all needed layers for the decoder.
"""
...
...
@@ -563,13 +588,23 @@ def wrap_decoder(trg_vocab_size,
slf_attn_pre_softmax_shape
,
slf_attn_post_softmax_shape
,
src_attn_pre_softmax_shape
,
src_attn_post_softmax_shape
,
)
src_attn_post_softmax_shape
,
caches
,
)
# Return logits for training and probs for inference.
predict
=
layers
.
reshape
(
x
=
layers
.
fc
(
input
=
dec_output
,
size
=
trg_vocab_size
,
bias_attr
=
False
,
num_flatten_dims
=
2
),
shape
=
[
-
1
,
trg_vocab_size
],
act
=
"softmax"
if
dec_inputs
is
None
else
None
)
if
not
WEIGHT_SHARING
:
predict
=
layers
.
reshape
(
x
=
layers
.
fc
(
input
=
dec_output
,
size
=
trg_vocab_size
,
bias_attr
=
False
,
num_flatten_dims
=
2
),
shape
=
[
-
1
,
trg_vocab_size
],
act
=
"softmax"
if
dec_inputs
is
None
else
None
)
else
:
predict
=
layers
.
reshape
(
x
=
layers
.
matmul
(
x
=
dec_output
,
y
=
fluid
.
get_var
(
word_emb_param_names
[
0
]),
transpose_y
=
True
),
shape
=
[
-
1
,
trg_vocab_size
],
act
=
"softmax"
if
dec_inputs
is
None
else
None
)
return
predict
fluid/neural_machine_translation/transformer/train.py
浏览文件 @
a6ec3a0d
...
...
@@ -288,6 +288,7 @@ def train(args):
start_mark
=
args
.
special_token
[
0
],
end_mark
=
args
.
special_token
[
1
],
unk_mark
=
args
.
special_token
[
2
],
max_length
=
ModelHyperParams
.
max_length
,
clip_last_batch
=
False
)
train_data
=
read_multiple
(
reader
=
train_data
.
batch_generator
)
...
...
@@ -319,6 +320,7 @@ def train(args):
start_mark
=
args
.
special_token
[
0
],
end_mark
=
args
.
special_token
[
1
],
unk_mark
=
args
.
special_token
[
2
],
max_length
=
ModelHyperParams
.
max_length
,
clip_last_batch
=
False
,
shuffle
=
False
,
shuffle_batch
=
False
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录