Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
bca3c03d
M
models
项目概览
PaddlePaddle
/
models
大约 2 年 前同步成功
通知
232
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bca3c03d
编写于
5月 08, 2018
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add reader, ParallelExecutor and refine for Transformer
上级
e7684f07
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
849 addition
and
323 deletion
+849
-323
fluid/neural_machine_translation/transformer/config.py
fluid/neural_machine_translation/transformer/config.py
+115
-32
fluid/neural_machine_translation/transformer/infer.py
fluid/neural_machine_translation/transformer/infer.py
+69
-13
fluid/neural_machine_translation/transformer/model.py
fluid/neural_machine_translation/transformer/model.py
+32
-175
fluid/neural_machine_translation/transformer/optim.py
fluid/neural_machine_translation/transformer/optim.py
+4
-7
fluid/neural_machine_translation/transformer/reader.py
fluid/neural_machine_translation/transformer/reader.py
+341
-0
fluid/neural_machine_translation/transformer/train.py
fluid/neural_machine_translation/transformer/train.py
+288
-96
未找到文件。
fluid/neural_machine_translation/transformer/config.py
浏览文件 @
bca3c03d
class
TrainTaskConfig
(
object
):
use_gpu
=
Fals
e
use_gpu
=
Tru
e
# the epoch number to train.
pass_num
=
2
pass_num
=
30
# the number of sequences contained in a mini-batch.
batch_size
=
64
batch_size
=
32
# the hyper parameters for Adam optimizer.
learning_rate
=
0.001
# This static learning_rate will multiply LearningRateScheduler
# derived learning rate the to get the final learning rate.
learning_rate
=
1
beta1
=
0.9
beta2
=
0.98
eps
=
1e-9
# the parameters for learning rate scheduling.
warmup_steps
=
4000
# the flag indicating to use average loss or sum loss when training.
use_avg_cost
=
False
use_avg_cost
=
True
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps
=
0.1
# the directory for saving trained models.
model_dir
=
"trained_models"
# the directory for saving checkpoints.
ckpt_dir
=
"trained_ckpts"
# the directory for loading checkpoint.
# If provided, continue training from the checkpoint.
ckpt_path
=
None
# the parameter to initialize the learning rate scheduler.
# It should be provided if use checkpoints, since the checkpoint doesn't
# include the training step counter currently.
start_step
=
0
class
InferTaskConfig
(
object
):
use_gpu
=
Fals
e
use_gpu
=
Tru
e
# the number of examples in one run for sequence generation.
batch_size
=
10
# the parameters for beam search.
beam_size
=
5
max_length
=
30
# the number of decoded sentences to output.
n_best
=
1
# the flags indicating whether to output the special tokens.
output_bos
=
False
output_eos
=
False
output_unk
=
False
# the directory for loading the trained model.
model_path
=
"trained_models/pass_1.infer.model"
...
...
@@ -47,30 +54,24 @@ class ModelHyperParams(object):
# <unk> token has alreay been added. As for the <pad> token, any token
# included in dict can be used to pad, since the paddings' loss will be
# masked out and make no effect on parameter gradients.
# size of source word dictionary.
src_vocab_size
=
10000
# size of target word dictionay
trg_vocab_size
=
10000
# index for <bos> token
bos_idx
=
0
# index for <eos> token
eos_idx
=
1
# index for <unk> token
unk_idx
=
2
# max length of sequences.
# The size of position encoding table should at least plus 1, since the
# sinusoid position encoding starts from 1 and 0 can be used as the padding
# token for position encoding.
max_length
=
50
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model
=
512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid
=
1024
...
...
@@ -86,34 +87,116 @@ class ModelHyperParams(object):
dropout
=
0.1
def
merge_cfg_from_list
(
cfg_list
,
g_cfgs
):
"""
Set the above global configurations using the cfg_list.
"""
assert
len
(
cfg_list
)
%
2
==
0
for
key
,
value
in
zip
(
cfg_list
[
0
::
2
],
cfg_list
[
1
::
2
]):
for
g_cfg
in
g_cfgs
:
if
hasattr
(
g_cfg
,
key
):
try
:
value
=
eval
(
value
)
except
SyntaxError
:
# for file path
pass
setattr
(
g_cfg
,
key
,
value
)
break
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs
=
{
# The actual data shape of src_word is:
# [batch_size * max_src_len_in_batch, 1]
"src_word"
:
[(
1
*
(
ModelHyperParams
.
max_length
+
1
),
1L
),
"int64"
],
# The actual data shape of src_pos is:
# [batch_size * max_src_len_in_batch, 1]
"src_pos"
:
[(
1
*
(
ModelHyperParams
.
max_length
+
1
),
1L
),
"int64"
],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias"
:
[(
1
,
ModelHyperParams
.
n_head
,
(
ModelHyperParams
.
max_length
+
1
),
(
ModelHyperParams
.
max_length
+
1
)),
"float32"
],
# This shape input is used to reshape the output of embedding layer.
"src_data_shape"
:
[(
3L
,
),
"int32"
],
# This shape input is used to reshape before softmax in self attention.
"src_slf_attn_pre_softmax_shape"
:
[(
2L
,
),
"int32"
],
# This shape input is used to reshape after softmax in self attention.
"src_slf_attn_post_softmax_shape"
:
[(
4L
,
),
"int32"
],
# The actual data shape of trg_word is:
# [batch_size * max_trg_len_in_batch, 1]
"trg_word"
:
[(
1
*
(
ModelHyperParams
.
max_length
+
1
),
1L
),
"int64"
],
# The actual data shape of trg_pos is:
# [batch_size * max_trg_len_in_batch, 1]
"trg_pos"
:
[(
1
*
(
ModelHyperParams
.
max_length
+
1
),
1L
),
"int64"
],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
"trg_slf_attn_bias"
:
[(
1
,
ModelHyperParams
.
n_head
,
(
ModelHyperParams
.
max_length
+
1
),
(
ModelHyperParams
.
max_length
+
1
)),
"float32"
],
# This input is used to remove attention weights on paddings of the source
# input in the encoder-decoder attention.
# The actual data shape of trg_src_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias"
:
[(
1
,
ModelHyperParams
.
n_head
,
(
ModelHyperParams
.
max_length
+
1
),
(
ModelHyperParams
.
max_length
+
1
)),
"float32"
],
# This shape input is used to reshape the output of embedding layer.
"trg_data_shape"
:
[(
3L
,
),
"int32"
],
# This shape input is used to reshape before softmax in self attention.
"trg_slf_attn_pre_softmax_shape"
:
[(
2L
,
),
"int32"
],
# This shape input is used to reshape after softmax in self attention.
"trg_slf_attn_post_softmax_shape"
:
[(
4L
,
),
"int32"
],
# This shape input is used to reshape before softmax in encoder-decoder
# attention.
"trg_src_attn_pre_softmax_shape"
:
[(
2L
,
),
"int32"
],
# This shape input is used to reshape after softmax in encoder-decoder
# attention.
"trg_src_attn_post_softmax_shape"
:
[(
4L
,
),
"int32"
],
# This input is used in independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
"enc_output"
:
[(
1
,
(
ModelHyperParams
.
max_length
+
1
),
ModelHyperParams
.
d_model
),
"float32"
],
# The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_word"
:
[(
1
*
(
ModelHyperParams
.
max_length
+
1
),
1L
),
"int64"
],
# This input is used to mask out the loss of paddding tokens.
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight"
:
[(
1
*
(
ModelHyperParams
.
max_length
+
1
),
1L
),
"float32"
],
}
# Names of position encoding table which will be initialized externally.
pos_enc_param_names
=
(
"src_pos_enc_table"
,
"trg_pos_enc_table"
,
)
# Names of all data layers in encoder listed in order.
encoder_input_data_names
=
(
# separated inputs for different usages.
encoder_data_input_fields
=
(
"src_word"
,
"src_pos"
,
"src_slf_attn_bias"
,
"src_slf_attn_bias"
,
)
encoder_util_input_fields
=
(
"src_data_shape"
,
"src_slf_attn_pre_softmax_shape"
,
"src_slf_attn_post_softmax_shape"
,
)
# Names of all data layers in decoder listed in order.
decoder_input_data_names
=
(
decoder_data_input_fields
=
(
"trg_word"
,
"trg_pos"
,
"trg_slf_attn_bias"
,
"trg_src_attn_bias"
,
"enc_output"
,
)
decoder_util_input_fields
=
(
"trg_data_shape"
,
"trg_slf_attn_pre_softmax_shape"
,
"trg_slf_attn_post_softmax_shape"
,
"trg_src_attn_pre_softmax_shape"
,
"trg_src_attn_post_softmax_shape"
,
"enc_output"
,
)
# Names of label related data layers listed in order.
label_data_names
=
(
"trg_src_attn_post_softmax_shape"
,
)
label_data_input_fields
=
(
"lbl_word"
,
"lbl_weight"
,
)
fluid/neural_machine_translation/transformer/infer.py
浏览文件 @
bca3c03d
import
argparse
import
numpy
as
np
import
paddle
...
...
@@ -6,9 +7,52 @@ import paddle.fluid as fluid
import
model
from
model
import
wrap_encoder
as
encoder
from
model
import
wrap_decoder
as
decoder
from
config
import
InferTaskConfig
,
ModelHyperParams
,
\
encoder_input_data_names
,
decoder_input_data_names
from
config
import
*
from
train
import
pad_batch_data
import
reader
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
"Training for Transformer."
)
parser
.
add_argument
(
"--src_vocab_fpath"
,
type
=
str
,
required
=
True
,
help
=
"The path of vocabulary file of source language."
)
parser
.
add_argument
(
"--trg_vocab_fpath"
,
type
=
str
,
required
=
True
,
help
=
"The path of vocabulary file of target language."
)
parser
.
add_argument
(
"--test_file_pattern"
,
type
=
str
,
required
=
True
,
help
=
"The pattern to match test data files."
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
50
,
help
=
"The number of examples in one run for sequence generation."
)
parser
.
add_argument
(
"--pool_size"
,
type
=
int
,
default
=
10000
,
help
=
"The buffer size to pool data."
)
parser
.
add_argument
(
"--special_token"
,
type
=
str
,
default
=
[
"<s>"
,
"<e>"
,
"<unk>"
],
nargs
=
3
,
help
=
"The <bos>, <eos> and <unk> tokens in the dictionary."
)
parser
.
add_argument
(
'opts'
,
help
=
'See config.py for all options'
,
default
=
None
,
nargs
=
argparse
.
REMAINDER
)
args
=
parser
.
parse_args
()
merge_cfg_from_list
(
args
.
opts
,
[
InferTaskConfig
,
ModelHyperParams
])
return
args
def
translate_batch
(
exe
,
...
...
@@ -243,7 +287,7 @@ def translate_batch(exe,
return
seqs
,
scores
[:,
:
n_best
].
tolist
()
def
main
(
):
def
infer
(
args
):
place
=
fluid
.
CUDAPlace
(
0
)
if
InferTaskConfig
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
...
...
@@ -292,13 +336,23 @@ def main():
decoder_program
=
fluid
.
io
.
get_inference_program
(
target_vars
=
[
predict
],
main_program
=
decoder_program
)
test_data
=
paddle
.
batch
(
paddle
.
dataset
.
wmt16
.
test
(
ModelHyperParams
.
src_vocab_size
,
ModelHyperParams
.
trg_vocab_size
),
batch_size
=
InferTaskConfig
.
batch_size
)
test_data
=
reader
.
DataReader
(
src_vocab_fpath
=
args
.
src_vocab_fpath
,
trg_vocab_fpath
=
args
.
trg_vocab_fpath
,
fpattern
=
args
.
test_file_pattern
,
batch_size
=
args
.
batch_size
,
use_token_batch
=
False
,
pool_size
=
args
.
pool_size
,
sort_type
=
reader
.
SortType
.
NONE
,
shuffle
=
False
,
shuffle_batch
=
False
,
start_mark
=
args
.
special_token
[
0
],
end_mark
=
args
.
special_token
[
1
],
unk_mark
=
args
.
special_token
[
2
],
clip_last_batch
=
False
)
trg_idx2word
=
paddle
.
dataset
.
wmt16
.
get
_dict
(
"de"
,
dict_size
=
ModelHyperParams
.
trg_vocab_size
,
reverse
=
True
)
trg_idx2word
=
test_data
.
_load
_dict
(
dict_path
=
args
.
trg_vocab_fpath
,
reverse
=
True
)
def
post_process_seq
(
seq
,
bos_idx
=
ModelHyperParams
.
bos_idx
,
...
...
@@ -320,15 +374,16 @@ def main():
(
output_eos
or
idx
!=
eos_idx
),
seq
)
for
batch_id
,
data
in
enumerate
(
test_data
()):
for
batch_id
,
data
in
enumerate
(
test_data
.
batch_generator
()):
batch_seqs
,
batch_scores
=
translate_batch
(
exe
,
[
item
[
0
]
for
item
in
data
],
encoder_program
,
encoder_
input_data_name
s
,
encoder_
data_input_fields
+
encoder_util_input_field
s
,
[
enc_output
.
name
],
decoder_program
,
decoder_input_data_names
,
decoder_data_input_fields
[:
-
1
]
+
decoder_util_input_fields
+
(
decoder_data_input_fields
[
-
1
],
),
[
predict
.
name
],
InferTaskConfig
.
beam_size
,
InferTaskConfig
.
max_length
,
...
...
@@ -351,4 +406,5 @@ def main():
if
__name__
==
"__main__"
:
main
()
args
=
parse_args
()
infer
(
args
)
fluid/neural_machine_translation/transformer/model.py
浏览文件 @
bca3c03d
...
...
@@ -4,8 +4,7 @@ import numpy as np
import
paddle.fluid
as
fluid
import
paddle.fluid.layers
as
layers
from
config
import
TrainTaskConfig
,
pos_enc_param_names
,
\
encoder_input_data_names
,
decoder_input_data_names
,
label_data_names
from
config
import
*
def
position_encoding_init
(
n_position
,
d_pos_vec
):
...
...
@@ -171,7 +170,6 @@ def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
"""
Add residual connection, layer normalization and droput to the out tensor
optionally according to the value of process_cmd.
This will be used before or after multi-head attention and position-wise
feed-forward networks.
"""
...
...
@@ -206,7 +204,6 @@ def prepare_encoder(src_word,
"""Add word embeddings and position encodings.
The output tensor has a shape of:
[batch_size, max_src_length_in_batch, d_model].
This module is used at the bottom of the encoder stacks.
"""
src_word_emb
=
layers
.
embedding
(
...
...
@@ -245,7 +242,6 @@ def encoder_layer(enc_input,
pre_softmax_shape
=
None
,
post_softmax_shape
=
None
):
"""The encoder layers that can be stacked to form a deep encoder.
This module consits of a multi-head (self) attention followed by
position-wise feed-forward networks and both the two components companied
with the post_process_layer to add residual connection, layer normalization
...
...
@@ -306,7 +302,6 @@ def decoder_layer(dec_input,
src_attn_pre_softmax_shape
=
None
,
src_attn_post_softmax_shape
=
None
):
""" The layer to be stacked in decoder part.
The structure of this module is similar to that in the encoder part except
a multi-head attention is added to implement encoder-decoder attention.
"""
...
...
@@ -394,116 +389,19 @@ def decoder(dec_input,
return
dec_output
def
make_inputs
(
input_data_names
,
n_head
,
d_model
,
max_length
,
is_pos
,
slf_attn_bias_flag
,
src_attn_bias_flag
,
enc_output_flag
=
False
,
data_shape_flag
=
True
,
slf_attn_shape_flag
=
True
,
src_attn_shape_flag
=
True
):
def
make_all_inputs
(
input_fields
):
"""
Define the input data layers for the transformer model.
"""
input_layers
=
[]
batch_size
=
1
# Only for the infer-shape in compile time.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
# The actual data shape of word is:
# [batch_size * max_len_in_batch, 1]
word
=
layers
.
data
(
name
=
input_data_names
[
len
(
input_layers
)],
shape
=
[
batch_size
*
max_length
,
1
],
dtype
=
"int64"
,
append_batch_size
=
False
)
input_layers
+=
[
word
]
# This is used for position data or label weight.
# The actual data shape of pos is:
# [batch_size * max_len_in_batch, 1]
pos
=
layers
.
data
(
name
=
input_data_names
[
len
(
input_layers
)],
shape
=
[
batch_size
*
max_length
,
1
],
dtype
=
"int64"
if
is_pos
else
"float32"
,
append_batch_size
=
False
)
input_layers
+=
[
pos
]
if
slf_attn_bias_flag
:
# This input is used to remove attention weights on paddings for the
# encoder and to remove attention weights on subsequent words for the
# decoder.
# The actual data shape of slf_attn_bias_flag is:
# [batch_size, n_head, max_len_in_batch, max_len_in_batch]
slf_attn_bias
=
layers
.
data
(
name
=
input_data_names
[
len
(
input_layers
)],
shape
=
[
batch_size
,
n_head
,
max_length
,
max_length
],
dtype
=
"float32"
,
inputs
=
[]
for
input_field
in
input_fields
:
input_var
=
layers
.
data
(
name
=
input_field
,
shape
=
input_descs
[
input_field
][
0
],
dtype
=
input_descs
[
input_field
][
1
],
append_batch_size
=
False
)
input_layers
+=
[
slf_attn_bias
]
if
src_attn_bias_flag
:
# This input is used to remove attention weights on paddings. It's used
# in encoder-decoder attention.
# The actual data shape of slf_attn_bias_flag is:
# [batch_size, n_head, trg_max_len_in_batch, src_max_len_in_batch]
src_attn_bias
=
layers
.
data
(
name
=
input_data_names
[
len
(
input_layers
)],
shape
=
[
batch_size
,
n_head
,
max_length
,
max_length
],
dtype
=
"float32"
,
append_batch_size
=
False
)
input_layers
+=
[
src_attn_bias
]
if
data_shape_flag
:
# This input is used to reshape the output of embedding layer.
data_shape
=
layers
.
data
(
name
=
input_data_names
[
len
(
input_layers
)],
shape
=
[
3
],
dtype
=
"int32"
,
append_batch_size
=
False
)
input_layers
+=
[
data_shape
]
if
slf_attn_shape_flag
:
# This shape input is used to reshape before softmax in self attention.
slf_attn_pre_softmax_shape
=
layers
.
data
(
name
=
input_data_names
[
len
(
input_layers
)],
shape
=
[
2
],
dtype
=
"int32"
,
append_batch_size
=
False
)
input_layers
+=
[
slf_attn_pre_softmax_shape
]
# This shape input is used to reshape after softmax in self attention.
slf_attn_post_softmax_shape
=
layers
.
data
(
name
=
input_data_names
[
len
(
input_layers
)],
shape
=
[
4
],
dtype
=
"int32"
,
append_batch_size
=
False
)
input_layers
+=
[
slf_attn_post_softmax_shape
]
if
src_attn_shape_flag
:
# This shape input is used to reshape before softmax in encoder-decoder
# attention.
src_attn_pre_softmax_shape
=
layers
.
data
(
name
=
input_data_names
[
len
(
input_layers
)],
shape
=
[
2
],
dtype
=
"int32"
,
append_batch_size
=
False
)
input_layers
+=
[
src_attn_pre_softmax_shape
]
# This shape input is used to reshape after softmax in encoder-decoder
# attention.
src_attn_post_softmax_shape
=
layers
.
data
(
name
=
input_data_names
[
len
(
input_layers
)],
shape
=
[
4
],
dtype
=
"int32"
,
append_batch_size
=
False
)
input_layers
+=
[
src_attn_post_softmax_shape
]
if
enc_output_flag
:
# This input is used in independent decoder program for inference.
# The actual data shape of slf_attn_bias_flag is:
# [batch_size, max_len_in_batch, d_model]
enc_output
=
layers
.
data
(
name
=
input_data_names
[
len
(
input_layers
)],
shape
=
[
batch_size
,
max_length
,
d_model
],
dtype
=
"float32"
,
append_batch_size
=
False
)
input_layers
+=
[
enc_output
]
return
input_layers
inputs
.
append
(
input_var
)
return
inputs
def
transformer
(
...
...
@@ -516,19 +414,10 @@ def transformer(
d_value
,
d_model
,
d_inner_hid
,
dropout_rate
,
):
enc_inputs
=
make_inputs
(
encoder_input_data_names
,
n_head
,
d_model
,
max_length
,
is_pos
=
True
,
slf_attn_bias_flag
=
True
,
src_attn_bias_flag
=
False
,
enc_output_flag
=
False
,
data_shape_flag
=
True
,
slf_attn_shape_flag
=
True
,
src_attn_shape_flag
=
False
)
dropout_rate
,
label_smooth_eps
,
):
enc_inputs
=
make_all_inputs
(
encoder_data_input_fields
+
encoder_util_input_fields
)
enc_output
=
wrap_encoder
(
src_vocab_size
,
...
...
@@ -542,18 +431,8 @@ def transformer(
dropout_rate
,
enc_inputs
,
)
dec_inputs
=
make_inputs
(
decoder_input_data_names
,
n_head
,
d_model
,
max_length
,
is_pos
=
True
,
slf_attn_bias_flag
=
True
,
src_attn_bias_flag
=
True
,
enc_output_flag
=
False
,
data_shape_flag
=
True
,
slf_attn_shape_flag
=
True
,
src_attn_shape_flag
=
True
)
dec_inputs
=
make_all_inputs
(
decoder_data_input_fields
[:
-
1
]
+
decoder_util_input_fields
)
predict
=
wrap_decoder
(
trg_vocab_size
,
...
...
@@ -570,19 +449,17 @@ def transformer(
# Padding index do not contribute to the total loss. The weights is used to
# cancel padding index in calculating the loss.
gold
,
weights
=
make_inputs
(
label_data_names
,
n_head
,
d_model
,
max_length
,
is_pos
=
False
,
slf_attn_bias_flag
=
False
,
src_attn_bias_flag
=
False
,
enc_output_flag
=
False
,
data_shape_flag
=
False
,
slf_attn_shape_flag
=
False
,
src_attn_shape_flag
=
False
)
cost
=
layers
.
softmax_with_cross_entropy
(
logits
=
predict
,
label
=
gold
)
label
,
weights
=
make_all_inputs
(
label_data_input_fields
)
if
label_smooth_eps
:
label
=
layers
.
label_smooth
(
label
=
layers
.
one_hot
(
input
=
label
,
depth
=
trg_vocab_size
),
epsilon
=
label_smooth_eps
)
cost
=
layers
.
softmax_with_cross_entropy
(
logits
=
predict
,
label
=
label
,
soft_label
=
True
if
label_smooth_eps
else
False
)
# cost = layers.softmax_with_cross_entropy(logits=predict, label=gold)
weighted_cost
=
cost
*
weights
sum_cost
=
layers
.
reduce_sum
(
weighted_cost
)
token_num
=
layers
.
reduce_sum
(
weights
)
...
...
@@ -607,18 +484,8 @@ def wrap_encoder(src_vocab_size,
# This is used to implement independent encoder program in inference.
src_word
,
src_pos
,
src_slf_attn_bias
,
src_data_shape
,
\
slf_attn_pre_softmax_shape
,
slf_attn_post_softmax_shape
=
\
make_inputs
(
encoder_input_data_names
,
n_head
,
d_model
,
max_length
,
is_pos
=
True
,
slf_attn_bias_flag
=
True
,
src_attn_bias_flag
=
False
,
enc_output_flag
=
False
,
data_shape_flag
=
True
,
slf_attn_shape_flag
=
True
,
src_attn_shape_flag
=
False
)
make_all_inputs
(
encoder_data_input_fields
+
encoder_util_input_fields
)
else
:
src_word
,
src_pos
,
src_slf_attn_bias
,
src_data_shape
,
\
slf_attn_pre_softmax_shape
,
slf_attn_post_softmax_shape
=
\
...
...
@@ -663,20 +530,10 @@ def wrap_decoder(trg_vocab_size,
if
dec_inputs
is
None
:
# This is used to implement independent decoder program in inference.
trg_word
,
trg_pos
,
trg_slf_attn_bias
,
trg_src_attn_bias
,
\
trg_data_shape
,
slf_attn_pre_softmax_shape
,
\
enc_output
,
trg_data_shape
,
slf_attn_pre_softmax_shape
,
\
slf_attn_post_softmax_shape
,
src_attn_pre_softmax_shape
,
\
src_attn_post_softmax_shape
,
enc_output
=
make_inputs
(
decoder_input_data_names
,
n_head
,
d_model
,
max_length
,
is_pos
=
True
,
slf_attn_bias_flag
=
True
,
src_attn_bias_flag
=
True
,
enc_output_flag
=
True
,
data_shape_flag
=
True
,
slf_attn_shape_flag
=
True
,
src_attn_shape_flag
=
True
)
src_attn_post_softmax_shape
=
make_all_inputs
(
decoder_data_input_fields
+
decoder_util_input_fields
)
else
:
trg_word
,
trg_pos
,
trg_slf_attn_bias
,
trg_src_attn_bias
,
\
trg_data_shape
,
slf_attn_pre_softmax_shape
,
\
...
...
fluid/neural_machine_translation/transformer/optim.py
浏览文件 @
bca3c03d
...
...
@@ -14,27 +14,24 @@ class LearningRateScheduler(object):
def
__init__
(
self
,
d_model
,
warmup_steps
,
place
,
learning_rate
=
0.001
,
current_steps
=
0
,
name
=
"learning_rate"
):
self
.
current_steps
=
current_steps
self
.
warmup_steps
=
warmup_steps
self
.
d_model
=
d_model
self
.
static_lr
=
learning_rate
self
.
learning_rate
=
layers
.
create_global_var
(
name
=
name
,
shape
=
[
1
],
value
=
float
(
learning_rate
),
dtype
=
"float32"
,
persistable
=
True
)
self
.
place
=
place
def
update_learning_rate
(
self
,
data_input
):
def
update_learning_rate
(
self
):
self
.
current_steps
+=
1
lr_value
=
np
.
power
(
self
.
d_model
,
-
0.5
)
*
np
.
min
([
np
.
power
(
self
.
current_steps
,
-
0.5
),
np
.
power
(
self
.
warmup_steps
,
-
1.5
)
*
self
.
current_steps
])
lr_tensor
=
fluid
.
LoDTensor
()
lr_tensor
.
set
(
np
.
array
([
lr_value
],
dtype
=
"float32"
),
self
.
place
)
data_input
[
self
.
learning_rate
.
name
]
=
lr_tensor
])
*
self
.
static_lr
return
np
.
array
([
lr_value
],
dtype
=
"float32"
)
fluid/neural_machine_translation/transformer/reader.py
0 → 100644
浏览文件 @
bca3c03d
import
os
import
tarfile
import
glob
import
random
class
SortType
(
object
):
GLOBAL
=
'global'
POOL
=
'pool'
NONE
=
"none"
class
EndEpoch
():
pass
class
Pool
(
object
):
def
__init__
(
self
,
sample_generator
,
pool_size
,
sort
):
self
.
_pool_size
=
pool_size
self
.
_pool
=
[]
self
.
_sample_generator
=
sample_generator
()
self
.
_end
=
False
self
.
_sort
=
sort
def
_fill
(
self
):
while
len
(
self
.
_pool
)
<
self
.
_pool_size
and
not
self
.
_end
:
try
:
sample
=
self
.
_sample_generator
.
next
()
self
.
_pool
.
append
(
sample
)
except
StopIteration
as
e
:
self
.
_end
=
True
break
if
self
.
_sort
:
self
.
_pool
.
sort
(
key
=
lambda
sample
:
max
(
len
(
sample
[
0
]),
len
(
sample
[
1
]))
if
len
(
sample
)
>
1
else
len
(
sample
[
0
])
)
if
self
.
_end
and
len
(
self
.
_pool
)
<
self
.
_pool_size
:
self
.
_pool
.
append
(
EndEpoch
())
def
push_back
(
self
,
samples
):
if
len
(
self
.
_pool
)
!=
0
:
raise
Exception
(
"Pool should be empty."
)
if
len
(
samples
)
>=
self
.
_pool_size
:
raise
Exception
(
"Capacity of pool should be greater than a batch. "
"Please enlarge `pool_size`."
)
for
sample
in
samples
:
self
.
_pool
.
append
(
sample
)
self
.
_fill
()
def
next
(
self
,
look
=
False
):
if
len
(
self
.
_pool
)
==
0
:
return
None
else
:
return
self
.
_pool
[
0
]
if
look
else
self
.
_pool
.
pop
(
0
)
class
DataReader
(
object
):
"""
The data reader loads all data from files and produces batches of data
in the way corresponding to settings.
number of tokens or number of sequences.
"""
def
__init__
(
self
,
src_vocab_fpath
,
trg_vocab_fpath
,
fpattern
,
batch_size
,
pool_size
,
sort_type
=
SortType
.
NONE
,
clip_last_batch
=
True
,
tar_fname
=
None
,
min_length
=
0
,
max_length
=
100
,
shuffle
=
True
,
shuffle_batch
=
False
,
use_token_batch
=
False
,
delimiter
=
"
\t
"
,
start_mark
=
"<s>"
,
end_mark
=
"<e>"
,
unk_mark
=
"<unk>"
,
seed
=
0
):
"""
Load all data from files and set the settings to make mini-batches.
:param src_vocab_fpath: The path of vocabulary file of source language.
:type src_vocab_fpath: basestring
:param trg_vocab_fpath: The path of vocabulary file of target language.
:type trg_vocab_fpath: basestring
:param fpattern: The pattern to match data files.
:type fpattern: basestring
:param batch_size: The number of sequences contained in a mini-batch.
or the maximum number of tokens (include paddings) contained in a
mini-batch.
:type batch_size: int
:param pool_size: The buffer size to pool data.
:type pool_size: int
:param sort_type: The grain to sort by length: 'global' for all
instances; 'pool' for instances in pool; 'none' for no sort.
:type sort_type: basestring
:param sort_type: The grain to sort by length: 'global' for all
instances; 'pool' for instances in pool; 'none' for no sort.
:type sort_type: basestring
:param clip_last_batch: Whether to clip the last uncompleted batch.
:type clip_last_batch: bool
:param tar_fname: The data file in tar if fpattern matches a tar file.
:type tar_fname: basestring
:param min_length: The minimum length used to filt sequences.
:type min_length: int
:param max_length: The maximum length used to filt sequences.
:type max_length: int
:param shuffle: Whether to shuffle all instances.
:type shuffle: bool
:param shuffle_batch: Whether to shuffle the generated batches.
:type shuffle_batch: bool
:param use_token_batch: Whether to produce batch data according to
token number.
:type use_token_batch: bool
:param delimiter: The delimiter used to split source and target in each
line of data file.
:type delimiter: basestring
:param start_mark: The token representing for the beginning of
sentences in dictionary.
:type start_mark: basestring
:param end_mark: The token representing for the end of sentences
in dictionary.
:type end_mark: basestring
:param unk_mark: The token representing for unknown word in dictionary.
:type unk_mark: basestring
:param seed: The seed for random.
:type seed: int
"""
self
.
_src_vocab
=
self
.
_load_dict
(
src_vocab_fpath
)
self
.
_only_src
=
True
if
trg_vocab_fpath
is
not
None
:
self
.
_trg_vocab
=
self
.
_load_dict
(
trg_vocab_fpath
)
self
.
_only_src
=
False
self
.
_pool_size
=
pool_size
self
.
_batch_size
=
batch_size
self
.
_use_token_batch
=
use_token_batch
self
.
_sort_type
=
sort_type
self
.
_clip_last_batch
=
clip_last_batch
self
.
_shuffle
=
shuffle
self
.
_shuffle_batch
=
shuffle_batch
self
.
_min_length
=
min_length
self
.
_max_length
=
max_length
self
.
_delimiter
=
delimiter
self
.
_epoch_batches
=
[]
src_seq_words
,
trg_seq_words
=
self
.
_load_data
(
fpattern
,
tar_fname
)
self
.
_src_seq_ids
=
[[
self
.
_src_vocab
.
get
(
word
,
self
.
_src_vocab
.
get
(
unk_mark
))
for
word
in
([
start_mark
]
+
src_seq
+
[
end_mark
])
]
for
src_seq
in
src_seq_words
]
self
.
_sample_count
=
len
(
self
.
_src_seq_ids
)
if
not
self
.
_only_src
:
self
.
_trg_seq_ids
=
[[
self
.
_trg_vocab
.
get
(
word
,
self
.
_trg_vocab
.
get
(
unk_mark
))
for
word
in
([
start_mark
]
+
trg_seq
+
[
end_mark
])
]
for
trg_seq
in
trg_seq_words
]
if
len
(
self
.
_trg_seq_ids
)
!=
self
.
_sample_count
:
raise
Exception
(
"Inconsistent sample count between "
"source sequences and target sequences."
)
else
:
self
.
_trg_seq_ids
=
None
self
.
_sample_idxs
=
[
i
for
i
in
xrange
(
self
.
_sample_count
)]
self
.
_sorted
=
False
random
.
seed
(
seed
)
def
_parse_file
(
self
,
f_obj
):
src_seq_words
=
[]
trg_seq_words
=
[]
for
line
in
f_obj
:
fields
=
line
.
strip
().
split
(
self
.
_delimiter
)
if
len
(
fields
)
!=
2
or
(
self
.
_only_src
and
len
(
fields
)
!=
1
):
continue
sample_words
=
[]
is_valid_sample
=
True
max_len
=
-
1
for
i
,
seq
in
enumerate
(
fields
):
seq_words
=
seq
.
split
()
max_len
=
max
(
max_len
,
len
(
seq_words
))
if
len
(
seq_words
)
==
0
or
\
len
(
seq_words
)
<
self
.
_min_length
or
\
len
(
seq_words
)
>
self
.
_max_length
or
\
(
self
.
_use_token_batch
and
max_len
>
self
.
_batch_size
):
is_valid_sample
=
False
break
sample_words
.
append
(
seq_words
)
if
not
is_valid_sample
:
continue
src_seq_words
.
append
(
sample_words
[
0
])
if
not
self
.
_only_src
:
trg_seq_words
.
append
(
sample_words
[
1
])
return
(
src_seq_words
,
trg_seq_words
)
def
_load_data
(
self
,
fpattern
,
tar_fname
):
fpaths
=
glob
.
glob
(
fpattern
)
src_seq_words
=
[]
trg_seq_words
=
[]
if
len
(
fpaths
)
==
1
and
tarfile
.
is_tarfile
(
fpaths
[
0
]):
if
tar_fname
is
None
:
raise
Exception
(
"If tar file provided, please set tar_fname."
)
f
=
tarfile
.
open
(
fpaths
[
0
],
'r'
)
part_file_data
=
self
.
_parse_file
(
f
.
extractfile
(
tar_fname
))
src_seq_words
=
part_file_data
[
0
]
trg_seq_words
=
part_file_data
[
1
]
else
:
for
fpath
in
fpaths
:
if
not
os
.
path
.
isfile
(
fpath
):
raise
IOError
(
"Invalid file: %s"
%
fpath
)
part_file_data
=
self
.
_parse_file
(
open
(
fpath
,
'r'
))
src_seq_words
.
extend
(
part_file_data
[
0
])
trg_seq_words
.
extend
(
part_file_data
[
1
])
return
src_seq_words
,
trg_seq_words
def
_load_dict
(
self
,
dict_path
,
reverse
=
False
):
word_dict
=
{}
with
open
(
dict_path
,
"r"
)
as
fdict
:
for
idx
,
line
in
enumerate
(
fdict
):
if
reverse
:
word_dict
[
idx
]
=
line
.
strip
()
else
:
word_dict
[
line
.
strip
()]
=
idx
return
word_dict
def
_sample_generator
(
self
):
if
self
.
_sort_type
==
SortType
.
GLOBAL
:
if
not
self
.
_sorted
:
self
.
_sample_idxs
.
sort
(
key
=
lambda
idx
:
max
(
len
(
self
.
_src_seq_ids
[
idx
]),
len
(
self
.
_trg_seq_ids
[
idx
]
if
not
self
.
_only_src
else
0
))
)
self
.
_sorted
=
True
elif
self
.
_shuffle
:
random
.
shuffle
(
self
.
_sample_idxs
)
for
sample_idx
in
self
.
_sample_idxs
:
if
self
.
_only_src
:
yield
(
self
.
_src_seq_ids
[
sample_idx
])
else
:
yield
(
self
.
_src_seq_ids
[
sample_idx
],
self
.
_trg_seq_ids
[
sample_idx
][:
-
1
],
self
.
_trg_seq_ids
[
sample_idx
][
1
:])
def
batch_generator
(
self
):
pool
=
Pool
(
self
.
_sample_generator
,
self
.
_pool_size
,
True
if
self
.
_sort_type
==
SortType
.
POOL
else
False
)
def
next_batch
():
batch_data
=
[]
max_len
=
-
1
batch_max_seq_len
=
-
1
while
True
:
sample
=
pool
.
next
(
look
=
True
)
if
sample
is
None
:
pool
.
push_back
(
batch_data
)
batch_data
=
[]
continue
if
isinstance
(
sample
,
EndEpoch
):
return
batch_data
,
batch_max_seq_len
,
True
max_len
=
max
(
max_len
,
len
(
sample
[
0
]))
if
not
self
.
_only_src
:
max_len
=
max
(
max_len
,
len
(
sample
[
1
]))
if
self
.
_use_token_batch
:
if
max_len
*
(
len
(
batch_data
)
+
1
)
<
self
.
_batch_size
:
batch_max_seq_len
=
max_len
batch_data
.
append
(
pool
.
next
())
else
:
return
batch_data
,
batch_max_seq_len
,
False
else
:
if
len
(
batch_data
)
<
self
.
_batch_size
:
batch_max_seq_len
=
max_len
batch_data
.
append
(
pool
.
next
())
else
:
return
batch_data
,
batch_max_seq_len
,
False
if
not
self
.
_shuffle_batch
:
batch_data
,
batch_max_seq_len
,
last_batch
=
next_batch
()
while
not
last_batch
:
yield
batch_data
batch_data
,
batch_max_seq_len
,
last_batch
=
next_batch
()
batch_size
=
len
(
batch_data
)
if
self
.
_use_token_batch
:
batch_size
*=
batch_max_seq_len
if
(
not
self
.
_clip_last_batch
and
len
(
batch_data
)
>
0
)
\
or
(
batch_size
==
self
.
_batch_size
):
yield
batch_data
else
:
# should re-generate batches
if
self
.
_sort_type
==
SortType
.
POOL
\
or
len
(
self
.
_epoch_batches
)
==
0
:
self
.
_epoch_batches
=
[]
batch_data
,
batch_max_seq_len
,
last_batch
=
next_batch
()
while
not
last_batch
:
self
.
_epoch_batches
.
append
(
batch_data
)
batch_data
,
batch_max_seq_len
,
last_batch
=
next_batch
()
batch_size
=
len
(
batch_data
)
if
self
.
_use_token_batch
:
batch_size
*=
batch_max_seq_len
if
(
not
self
.
_clip_last_batch
and
len
(
batch_data
)
>
0
)
\
or
(
batch_size
==
self
.
_batch_size
):
self
.
_epoch_batches
.
append
(
batch_data
)
random
.
shuffle
(
self
.
_epoch_batches
)
for
batch_data
in
self
.
_epoch_batches
:
yield
batch_data
fluid/neural_machine_translation/transformer/train.py
浏览文件 @
bca3c03d
import
os
import
time
import
argparse
import
ast
import
numpy
as
np
import
paddle
...
...
@@ -7,8 +9,78 @@ import paddle.fluid as fluid
from
model
import
transformer
,
position_encoding_init
from
optim
import
LearningRateScheduler
from
config
import
TrainTaskConfig
,
ModelHyperParams
,
pos_enc_param_names
,
\
encoder_input_data_names
,
decoder_input_data_names
,
label_data_names
from
config
import
*
import
reader
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
"Training for Transformer."
)
parser
.
add_argument
(
"--src_vocab_fpath"
,
type
=
str
,
required
=
True
,
help
=
"The path of vocabulary file of source language."
)
parser
.
add_argument
(
"--trg_vocab_fpath"
,
type
=
str
,
required
=
True
,
help
=
"The path of vocabulary file of target language."
)
parser
.
add_argument
(
"--train_file_pattern"
,
type
=
str
,
required
=
True
,
help
=
"The pattern to match training data files."
)
parser
.
add_argument
(
"--val_file_pattern"
,
type
=
str
,
help
=
"The pattern to match validation data files."
)
parser
.
add_argument
(
"--use_token_batch"
,
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
"The flag indicating whether to "
"produce batch data according to token number."
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
2000
,
help
=
"The number of sequences contained in a mini-batch, or the maximum "
"number of tokens (include paddings) contained in a mini-batch."
)
parser
.
add_argument
(
"--pool_size"
,
type
=
int
,
default
=
10000
,
help
=
"The buffer size to pool data."
)
parser
.
add_argument
(
"--sort_type"
,
default
=
"pool"
,
choices
=
(
"global"
,
"pool"
,
"none"
),
help
=
"The grain to sort by length: global for all instances; pool for "
"instances in pool; none for no sort."
)
parser
.
add_argument
(
"--shuffle"
,
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
"The flag indicating whether to shuffle instances in each pass."
)
parser
.
add_argument
(
"--shuffle_batch"
,
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
"The flag indicating whether to shuffle the data batches."
)
parser
.
add_argument
(
"--special_token"
,
type
=
str
,
default
=
[
"<s>"
,
"<e>"
,
"<unk>"
],
nargs
=
3
,
help
=
"The <bos>, <eos> and <unk> tokens in the dictionary."
)
parser
.
add_argument
(
'opts'
,
help
=
'See config.py for all options'
,
default
=
None
,
nargs
=
argparse
.
REMAINDER
)
args
=
parser
.
parse_args
()
merge_cfg_from_list
(
args
.
opts
,
[
TrainTaskConfig
,
ModelHyperParams
])
return
args
def
pad_batch_data
(
insts
,
...
...
@@ -17,13 +89,16 @@ def pad_batch_data(insts,
is_target
=
False
,
is_label
=
False
,
return_attn_bias
=
True
,
return_max_len
=
True
):
return_max_len
=
True
,
return_num_token
=
False
):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
return_list
=
[]
max_len
=
max
(
len
(
inst
)
for
inst
in
insts
)
num_token
=
reduce
(
lambda
x
,
y
:
x
+
y
,
[
len
(
inst
)
for
inst
in
insts
])
if
return_num_token
else
0
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data
=
np
.
array
(
...
...
@@ -44,8 +119,8 @@ def pad_batch_data(insts,
# This is used to avoid attention on paddings and subsequent
# words.
slf_attn_bias_data
=
np
.
ones
((
inst_data
.
shape
[
0
],
max_len
,
max_len
))
slf_attn_bias_data
=
np
.
triu
(
slf_attn_bias_data
,
1
).
reshape
(
[
-
1
,
1
,
max_len
,
max_len
])
slf_attn_bias_data
=
np
.
triu
(
slf_attn_bias_data
,
1
).
reshape
(
[
-
1
,
1
,
max_len
,
max_len
])
slf_attn_bias_data
=
np
.
tile
(
slf_attn_bias_data
,
[
1
,
n_head
,
1
,
1
])
*
[
-
1e9
]
else
:
...
...
@@ -59,11 +134,13 @@ def pad_batch_data(insts,
return_list
+=
[
slf_attn_bias_data
.
astype
(
"float32"
)]
if
return_max_len
:
return_list
+=
[
max_len
]
if
return_num_token
:
return_list
+=
[
num_token
]
return
return_list
if
len
(
return_list
)
>
1
else
return_list
[
0
]
def
prepare_batch_input
(
insts
,
input_data_names
,
src_pad_idx
,
trg
_pad_idx
,
n_head
,
d_model
):
def
prepare_batch_input
(
insts
,
data_input_names
,
util_input_names
,
src
_pad_idx
,
trg_pad_idx
,
n_head
,
d_model
):
"""
Put all padded data needed by training into a dict.
"""
...
...
@@ -75,139 +152,254 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
[
1
,
1
,
trg_max_len
,
1
]).
astype
(
"float32"
)
# These shape tensors are used in reshape_op.
src_data_shape
=
np
.
array
([
len
(
insts
)
,
src_max_len
,
d_model
],
dtype
=
"int32"
)
trg_data_shape
=
np
.
array
([
len
(
insts
)
,
trg_max_len
,
d_model
],
dtype
=
"int32"
)
src_data_shape
=
np
.
array
([
-
1
,
src_max_len
,
d_model
],
dtype
=
"int32"
)
trg_data_shape
=
np
.
array
([
-
1
,
trg_max_len
,
d_model
],
dtype
=
"int32"
)
src_slf_attn_pre_softmax_shape
=
np
.
array
(
[
-
1
,
src_slf_attn_bias
.
shape
[
-
1
]],
dtype
=
"int32"
)
src_slf_attn_post_softmax_shape
=
np
.
array
(
src_slf_attn_bias
.
shape
,
dtype
=
"int32"
)
[
-
1
]
+
list
(
src_slf_attn_bias
.
shape
[
1
:])
,
dtype
=
"int32"
)
trg_slf_attn_pre_softmax_shape
=
np
.
array
(
[
-
1
,
trg_slf_attn_bias
.
shape
[
-
1
]],
dtype
=
"int32"
)
trg_slf_attn_post_softmax_shape
=
np
.
array
(
trg_slf_attn_bias
.
shape
,
dtype
=
"int32"
)
[
-
1
]
+
list
(
trg_slf_attn_bias
.
shape
[
1
:])
,
dtype
=
"int32"
)
trg_src_attn_pre_softmax_shape
=
np
.
array
(
[
-
1
,
trg_src_attn_bias
.
shape
[
-
1
]],
dtype
=
"int32"
)
trg_src_attn_post_softmax_shape
=
np
.
array
(
trg_src_attn_bias
.
shape
,
dtype
=
"int32"
)
[
-
1
]
+
list
(
trg_src_attn_bias
.
shape
[
1
:])
,
dtype
=
"int32"
)
lbl_word
,
lbl_weight
=
pad_batch_data
(
lbl_word
,
lbl_weight
,
num_token
=
pad_batch_data
(
[
inst
[
2
]
for
inst
in
insts
],
trg_pad_idx
,
n_head
,
is_target
=
False
,
is_label
=
True
,
return_attn_bias
=
False
,
return_max_len
=
False
)
input_dict
=
dict
(
zip
(
input_data_names
,
[
src_word
,
src_pos
,
src_slf_attn_bias
,
src_data_shape
,
src_slf_attn_pre_softmax_shape
,
src_slf_attn_post_softmax_shape
,
trg_word
,
trg_pos
,
trg_slf_attn_bias
,
trg_src_attn_bias
,
trg_data_shape
,
trg_slf_attn_pre_softmax_shape
,
trg_slf_attn_post_softmax_shape
,
trg_src_attn_pre_softmax_shape
,
trg_src_attn_post_softmax_shape
,
lbl_word
,
lbl_weight
return_max_len
=
False
,
return_num_token
=
True
)
data_input_dict
=
dict
(
zip
(
data_input_names
,
[
src_word
,
src_pos
,
src_slf_attn_bias
,
trg_word
,
trg_pos
,
trg_slf_attn_bias
,
trg_src_attn_bias
,
lbl_word
,
lbl_weight
]))
util_input_dict
=
dict
(
zip
(
util_input_names
,
[
src_data_shape
,
src_slf_attn_pre_softmax_shape
,
src_slf_attn_post_softmax_shape
,
trg_data_shape
,
trg_slf_attn_pre_softmax_shape
,
trg_slf_attn_post_softmax_shape
,
trg_src_attn_pre_softmax_shape
,
trg_src_attn_post_softmax_shape
]))
return
input_dict
return
data_input_dict
,
util_input_dict
,
np
.
asarray
(
[
num_token
],
dtype
=
"float32"
)
def
main
():
place
=
fluid
.
CUDAPlace
(
0
)
if
TrainTaskConfig
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
def
train
(
args
):
dev_count
=
fluid
.
core
.
get_cuda_device_count
()
def
read_multiple
(
reader
,
count
=
dev_count
if
args
.
use_token_batch
else
1
,
clip_last
=
False
):
"""
Stack data from reader for multi-devices.
"""
def
__impl__
():
res
=
[]
for
item
in
reader
():
res
.
append
(
item
)
if
len
(
res
)
==
count
:
yield
res
res
=
[]
if
len
(
res
)
==
count
:
yield
res
elif
not
clip_last
:
data
=
[]
for
item
in
res
:
data
+=
item
if
len
(
data
)
>
count
:
inst_num_per_part
=
len
(
data
)
//
count
yield
[
data
[
inst_num_per_part
*
i
:
inst_num_per_part
*
(
i
+
1
)]
for
i
in
range
(
count
)
]
return
__impl__
def
split_data
(
data
,
num_part
=
dev_count
):
"""
Split data for each device.
"""
if
len
(
data
)
==
num_part
:
return
data
data
=
data
[
0
]
inst_num_per_part
=
len
(
data
)
//
num_part
return
[
data
[
inst_num_per_part
*
i
:
inst_num_per_part
*
(
i
+
1
)]
for
i
in
range
(
num_part
)
]
sum_cost
,
avg_cost
,
predict
,
token_num
=
transformer
(
ModelHyperParams
.
src_vocab_size
,
ModelHyperParams
.
trg_vocab_size
,
ModelHyperParams
.
max_length
+
1
,
ModelHyperParams
.
n_layer
,
ModelHyperParams
.
n_head
,
ModelHyperParams
.
d_key
,
ModelHyperParams
.
d_value
,
ModelHyperParams
.
d_model
,
ModelHyperParams
.
d_inner_hid
,
ModelHyperParams
.
dropout
)
ModelHyperParams
.
d_inner_hid
,
ModelHyperParams
.
dropout
,
TrainTaskConfig
.
label_smooth_eps
)
lr_scheduler
=
LearningRateScheduler
(
ModelHyperParams
.
d_model
,
TrainTaskConfig
.
warmup_steps
,
place
,
TrainTaskConfig
.
warmup_steps
,
TrainTaskConfig
.
learning_rate
)
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
lr_scheduler
.
learning_rate
,
beta1
=
TrainTaskConfig
.
beta1
,
beta2
=
TrainTaskConfig
.
beta2
,
epsilon
=
TrainTaskConfig
.
eps
)
optimizer
.
minimize
(
avg_cost
if
TrainTaskConfig
.
use_avg_cost
else
sum_cost
)
optimizer
.
minimize
(
sum_cost
)
place
=
fluid
.
CUDAPlace
(
0
)
if
TrainTaskConfig
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
# Initialize the parameters.
if
TrainTaskConfig
.
ckpt_path
:
fluid
.
io
.
load_persistables
(
exe
,
TrainTaskConfig
.
ckpt_path
)
lr_scheduler
.
current_steps
=
TrainTaskConfig
.
start_step
else
:
exe
.
run
(
fluid
.
framework
.
default_startup_program
())
train_data
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
dataset
.
wmt16
.
train
(
ModelHyperParams
.
src_vocab_size
,
ModelHyperParams
.
trg_vocab_size
),
buf_size
=
100000
),
batch_size
=
TrainTaskConfig
.
batch_size
)
train_data
=
reader
.
DataReader
(
src_vocab_fpath
=
args
.
src_vocab_fpath
,
trg_vocab_fpath
=
args
.
trg_vocab_fpath
,
fpattern
=
args
.
train_file_pattern
,
use_token_batch
=
args
.
use_token_batch
,
batch_size
=
args
.
batch_size
*
(
1
if
args
.
use_token_batch
else
dev_count
),
pool_size
=
args
.
pool_size
,
sort_type
=
args
.
sort_type
,
shuffle
=
args
.
shuffle
,
shuffle_batch
=
args
.
shuffle_batch
,
start_mark
=
args
.
special_token
[
0
],
end_mark
=
args
.
special_token
[
1
],
unk_mark
=
args
.
special_token
[
2
],
clip_last_batch
=
False
)
# Program to do validation.
train_data
=
read_multiple
(
reader
=
train_data
.
batch_generator
)
train_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
TrainTaskConfig
.
use_gpu
,
loss_name
=
sum_cost
.
name
,
use_default_grad_scale
=
False
)
def
test_context
():
# Context to do validation.
test_program
=
fluid
.
default_main_program
().
clone
()
with
fluid
.
program_guard
(
test_program
):
test_program
=
fluid
.
io
.
get_inference_program
([
avg_cost
])
val_data
=
paddle
.
batch
(
paddle
.
dataset
.
wmt16
.
validation
(
ModelHyperParams
.
src_vocab_size
,
ModelHyperParams
.
trg_vocab_size
),
batch_size
=
TrainTaskConfig
.
batch_size
)
def
test
(
exe
):
val_data
=
reader
.
DataReader
(
src_vocab_fpath
=
args
.
src_vocab_fpath
,
trg_vocab_fpath
=
args
.
trg_vocab_fpath
,
fpattern
=
args
.
val_file_pattern
,
use_token_batch
=
args
.
use_token_batch
,
batch_size
=
args
.
batch_size
*
(
1
if
args
.
use_token_batch
else
dev_count
),
pool_size
=
args
.
pool_size
,
sort_type
=
args
.
sort_type
,
start_mark
=
args
.
special_token
[
0
],
end_mark
=
args
.
special_token
[
1
],
unk_mark
=
args
.
special_token
[
2
],
clip_last_batch
=
False
,
shuffle
=
False
,
shuffle_batch
=
False
)
test_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
TrainTaskConfig
.
use_gpu
,
main_program
=
test_program
,
share_vars_from
=
train_exe
)
def
test
(
exe
=
test_exe
):
test_total_cost
=
0
test_total_token
=
0
for
batch_id
,
data
in
enumerate
(
val_data
()):
data_input
=
prepare_batch_input
(
data
,
encoder_input_data_names
+
decoder_input_data_names
[:
-
1
]
+
label_data_names
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
n_head
,
ModelHyperParams
.
d_model
)
test_sum_cost
,
test_token_num
=
exe
.
run
(
test_program
,
feed
=
data_input
,
fetch_list
=
[
sum_cost
,
token_num
],
use_program_cache
=
True
)
test_total_cost
+=
test_sum_cost
test_total_token
+=
test_token_num
test_data
=
read_multiple
(
reader
=
val_data
.
batch_generator
)
for
batch_id
,
data
in
enumerate
(
test_data
()):
feed_list
=
[]
for
place_id
,
data_buffer
in
enumerate
(
split_data
(
data
)):
data_input_dict
,
util_input_dict
,
_
=
prepare_batch_input
(
data_buffer
,
data_input_names
,
util_input_names
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
n_head
,
ModelHyperParams
.
d_model
)
feed_list
.
append
(
dict
(
data_input_dict
.
items
()
+
util_input_dict
.
items
()))
outs
=
exe
.
run
(
feed
=
feed_list
,
fetch_list
=
[
sum_cost
.
name
,
token_num
.
name
])
sum_cost_val
,
token_num_val
=
np
.
array
(
outs
[
0
]),
np
.
array
(
outs
[
1
])
test_total_cost
+=
sum_cost_val
.
sum
()
test_total_token
+=
token_num_val
.
sum
()
test_avg_cost
=
test_total_cost
/
test_total_token
test_ppl
=
np
.
exp
([
min
(
test_avg_cost
,
100
)])
return
test_avg_cost
,
test_ppl
# Initialize the parameters.
exe
.
run
(
fluid
.
framework
.
default_startup_program
())
for
pos_enc_param_name
in
pos_enc_param_names
:
pos_enc_param
=
fluid
.
global_scope
().
find_var
(
pos_enc_param_name
).
get_tensor
()
pos_enc_param
.
set
(
position_encoding_init
(
ModelHyperParams
.
max_length
+
1
,
ModelHyperParams
.
d_model
),
place
)
return
test
if
args
.
val_file_pattern
is
not
None
:
test
=
test_context
()
data_input_names
=
encoder_data_input_fields
+
decoder_data_input_fields
[:
-
1
]
+
label_data_input_fields
util_input_names
=
encoder_util_input_fields
+
decoder_util_input_fields
init
=
False
for
pass_id
in
xrange
(
TrainTaskConfig
.
pass_num
):
pass_start_time
=
time
.
time
()
for
batch_id
,
data
in
enumerate
(
train_data
()):
if
len
(
data
)
!=
TrainTaskConfig
.
batch_size
:
continue
data_input
=
prepare_batch_input
(
data
,
encoder_input_data_names
+
decoder_input_data_names
[:
-
1
]
+
label_data_names
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
n_head
,
feed_list
=
[]
total_num_token
=
0
lr_rate
=
lr_scheduler
.
update_learning_rate
()
for
place_id
,
data_buffer
in
enumerate
(
split_data
(
data
)):
data_input_dict
,
util_input_dict
,
num_token
=
prepare_batch_input
(
data_buffer
,
data_input_names
,
util_input_names
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
n_head
,
ModelHyperParams
.
d_model
)
total_num_token
+=
num_token
feed_list
.
append
(
dict
(
data_input_dict
.
items
()
+
util_input_dict
.
items
()
+
{
lr_scheduler
.
learning_rate
.
name
:
lr_rate
}.
items
()))
if
not
init
:
for
pos_enc_param_name
in
pos_enc_param_names
:
pos_enc
=
position_encoding_init
(
ModelHyperParams
.
max_length
+
1
,
ModelHyperParams
.
d_model
)
lr_scheduler
.
update_learning_rate
(
data_input
)
outs
=
exe
.
run
(
fluid
.
framework
.
default_main_program
(),
feed
=
data_input
,
fetch_list
=
[
sum_cost
,
avg_cost
],
use_program_cache
=
True
)
sum_cost_val
,
avg_cost_val
=
np
.
array
(
outs
[
0
]),
np
.
array
(
outs
[
1
])
feed_list
[
place_id
][
pos_enc_param_name
]
=
pos_enc
for
feed_dict
in
feed_list
:
feed_dict
[
sum_cost
.
name
+
"@GRAD"
]
=
1.
/
total_num_token
if
TrainTaskConfig
.
use_avg_cost
else
np
.
asarray
(
[
1.
],
dtype
=
"float32"
)
outs
=
train_exe
.
run
(
fetch_list
=
[
sum_cost
.
name
,
token_num
.
name
],
feed
=
feed_list
)
sum_cost_val
,
token_num_val
=
np
.
array
(
outs
[
0
]),
np
.
array
(
outs
[
1
])
total_sum_cost
=
sum_cost_val
.
sum
(
)
# sum the cost from multi-devices
total_token_num
=
token_num_val
.
sum
()
total_avg_cost
=
total_sum_cost
/
total_token_num
print
(
"epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f"
%
(
pass_id
,
batch_id
,
sum_cost_val
,
avg_cost_val
,
np
.
exp
([
min
(
avg_cost_val
[
0
],
100
)])))
(
pass_id
,
batch_id
,
total_sum_cost
,
total_avg_cost
,
np
.
exp
([
min
(
total_avg_cost
,
100
)])))
init
=
True
# Validate and save the model for inference.
val_avg_cost
,
val_ppl
=
test
(
exe
)
pass_end_time
=
time
.
time
()
time_consumed
=
pass_end_time
-
pass_start_time
print
(
"epoch: %d, val avg loss: %f, val ppl: %f, "
"consumed %fs"
%
(
pass_id
,
val_avg_cost
,
val_ppl
,
time_consumed
))
print
(
"epoch: %d, "
%
pass_id
+
(
"val avg loss: %f, val ppl: %f, "
%
test
()
if
args
.
val_file_pattern
is
not
None
else
""
)
+
"consumed %fs"
%
(
time
.
time
()
-
pass_start_time
))
fluid
.
io
.
save_persistables
(
exe
,
os
.
path
.
join
(
TrainTaskConfig
.
ckpt_dir
,
"pass_"
+
str
(
pass_id
)
+
".checkpoint"
))
fluid
.
io
.
save_inference_model
(
os
.
path
.
join
(
TrainTaskConfig
.
model_dir
,
"pass_"
+
str
(
pass_id
)
+
".infer.model"
),
encoder_input_data_names
+
decoder_input_data_names
[:
-
1
],
[
predict
],
exe
)
data_input_names
[:
-
2
]
+
util_input_names
,
[
predict
],
exe
)
if
__name__
==
"__main__"
:
main
()
args
=
parse_args
()
train
(
args
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录