Commit 6ef54e8e, authored on March 20, 2018 by guosheng:
Refine Transformer by following comments and fix the target self attention bias in inference.
Parent: ff80721e
Showing 3 changed files with 69 additions and 39 deletions (+69 -39):

fluid/neural_machine_translation/transformer/config.py   +11  -9
fluid/neural_machine_translation/transformer/infer.py    +24  -10
fluid/neural_machine_translation/transformer/model.py    +34  -20
fluid/neural_machine_translation/transformer/config.py
@@ -3,34 +3,36 @@ class TrainTaskConfig(object):
     # the epoch number to train.
     pass_num = 2

-    # number of sequences contained in a mini-batch.
+    # the number of sequences contained in a mini-batch.
     batch_size = 64

-    # the hyper params for Adam optimizer.
+    # the hyper parameters for Adam optimizer.
     learning_rate = 0.001
     beta1 = 0.9
     beta2 = 0.98
     eps = 1e-9

-    # the params for learning rate scheduling
+    # the parameters for learning rate scheduling.
     warmup_steps = 4000

-    # the directory for saving inference models
-    model_dir = "transformer_model"
+    # the directory for saving trained models.
+    model_dir = "trained_models"


 class InferTaskConfig(object):
     use_gpu = False
-    # number of sequences contained in a mini-batch
+    # the number of examples in one run for sequence generation.
+    # currently the batch size can only be set to 1.
     batch_size = 1

-    # the params for beam search
+    # the parameters for beam search.
     beam_size = 5
     max_length = 30
     # the number of decoded sentences to output.
     n_best = 1

-    # the directory for loading inference model
-    model_path = "transformer_model/pass_1.infer.model"
+    # the directory for loading the trained model.
+    model_path = "trained_models/pass_1.infer.model"


 class ModelHyperParams(object):
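Note that the rename touches both classes: TrainTaskConfig.model_dir (where pass checkpoints are saved) and InferTaskConfig.model_path (which checkpoint inference loads) now both point under trained_models. A minimal stand-alone sketch of that invariant; deriving the path with os.path.join is only an illustration, the commit itself spells the string out literally:

    import os


    class TrainTaskConfig(object):
        # the directory for saving trained models (as set in this commit).
        model_dir = "trained_models"


    class InferTaskConfig(object):
        # the trained model of pass 1 is loaded from that same directory.
        model_path = os.path.join(TrainTaskConfig.model_dir, "pass_1.infer.model")


    print(InferTaskConfig.model_path)  # trained_models/pass_1.infer.model on POSIX systems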
fluid/neural_machine_translation/transformer/infer.py
@@ -66,12 +66,19 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
         src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[-1], enc_in_data[-2], 1
+        # This is used to remove attention on subsequent words.
+        trg_slf_attn_bias = np.ones((batch_size * beam_size, trg_max_len,
+                                     trg_max_len))
+        trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
+            [-1, 1, trg_max_len, trg_max_len])
+        trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
+                             [-1e9]).astype("float32")
         # This is used to remove attention on the paddings of source sequences.
         trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_length, :],
                                     [beam_size, 1, trg_max_len, 1])
         enc_output = np.tile(enc_output, [beam_size, 1, 1])
-        # No need for trg_slf_attn_bias because of no paddings.
-        return trg_words, trg_pos, None, trg_src_attn_bias, enc_output
+        return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output

     def update_dec_in_data(dec_in_data, next_ids, active_beams):
         """
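This added block is the fix named in the commit message: at the start of inference the decoder's self-attention bias is now built explicitly instead of being returned as None. A self-contained NumPy sketch of the same np.triu construction, with small made-up sizes, showing that each position gets a -1e9 bias toward every later position, so the softmax gives those positions (near-)zero attention weight:

    import numpy as np

    # assumed toy sizes, not values from the model config
    batch_size, beam_size, n_head, trg_len = 1, 2, 2, 3

    bias = np.ones((batch_size * beam_size, trg_len, trg_len))
    # keep only the strictly upper triangle: the entries for "future" positions
    bias = np.triu(bias, 1).reshape([-1, 1, trg_len, trg_len])
    # replicate over attention heads and scale to a large negative additive bias
    bias = (np.tile(bias, [1, n_head, 1, 1]) * [-1e9]).astype("float32")

    print(bias.shape)  # (2, 2, 3, 3): (batch * beam, n_head, trg_len, trg_len)
    print(bias[0, 0])
    # row i has -1e9 in every column j > i and 0 elsewhere, roughly:
    # [[ 0.e+00 -1.e+09 -1.e+09]
    #  [ 0.e+00  0.e+00 -1.e+09]
    #  [ 0.e+00  0.e+00  0.e+00]]

At the first decoding step trg_max_len is 1, so the mask is a single zero, but building it keeps the feed consistent with the decoder program, which after the model.py change below always declares a trg_slf_attn_bias input.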
@@ -79,6 +86,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
             input data and dropping the finished instance beams.
         """
         trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = dec_in_data
+        trg_cur_len = len(next_ids[0]) + 1  # include the <bos>
         trg_words = np.array(
             [
                 beam_backtrace(
@@ -88,14 +96,22 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
             dtype="int64")
         trg_words = trg_words.reshape([-1, 1])
         trg_pos = np.array(
-            [range(1, len(next_ids[0]) + 2)] * len(active_beams) * beam_size,
+            [range(1, trg_cur_len + 1)] * len(active_beams) * beam_size,
             dtype="int64").reshape([-1, 1])
         active_beams_indice = (
             (np.array(active_beams) * beam_size)[:, np.newaxis] +
             np.array(range(beam_size))[np.newaxis, :]).flatten()
+        # This is used to remove attention on subsequent words.
+        trg_slf_attn_bias = np.ones((len(active_beams) * beam_size, trg_cur_len,
+                                     trg_cur_len))
+        trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
+            [-1, 1, trg_cur_len, trg_cur_len])
+        trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
+                             [-1e9]).astype("float32")
         # This is used to remove attention on the paddings of source sequences.
         trg_src_attn_bias = np.tile(trg_src_attn_bias[
             active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
-            [1, 1, len(next_ids[0]) + 1, 1])
+            [1, 1, trg_cur_len, 1])
         enc_output = enc_output[active_beams_indice, :, :]
         return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output
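update_dec_in_data keeps only the rows of still-active instances; active_beams_indice (unchanged context above) expands each surviving instance index into its beam_size consecutive row indices in the flattened (batch * beam) layout. A quick NumPy check with assumed values:

    import numpy as np

    beam_size = 3
    active_beams = [0, 2]  # instance 1 has finished decoding and is dropped

    active_beams_indice = (
        (np.array(active_beams) * beam_size)[:, np.newaxis] +
        np.array(range(beam_size))[np.newaxis, :]).flatten()

    print(active_beams_indice)  # [0 1 2 6 7 8]

Those indices are then used to slice trg_src_attn_bias and enc_output, so the decoder only keeps processing beams that can still grow.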
@@ -103,9 +119,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         enc_output)
     for i in range(max_length):
         predict_all = exe.run(decoder,
-                              feed=dict(
-                                  filter(lambda item: item[1] is not None,
-                                         zip(dec_in_names, dec_in_data))),
+                              feed=dict(zip(dec_in_names, dec_in_data)),
                               fetch_list=dec_out_names)[0]
         predict_all = np.log(predict_all)
         predict_all = (
@@ -206,9 +220,9 @@ def main():
         encoder_input_data_names, [enc_output.name], decoder_program,
         decoder_input_data_names, [predict.name], InferTaskConfig.beam_size,
         InferTaskConfig.max_length, InferTaskConfig.n_best,
-        InferTaskConfig.batch_size, ModelHyperParams.n_head,
-        ModelHyperParams.src_pad_idx, ModelHyperParams.trg_pad_idx,
-        ModelHyperParams.bos_idx, ModelHyperParams.eos_idx)
+        len(data), ModelHyperParams.n_head, ModelHyperParams.src_pad_idx,
+        ModelHyperParams.trg_pad_idx, ModelHyperParams.bos_idx,
+        ModelHyperParams.eos_idx)
     for i in range(len(batch_seqs)):
         seqs = batch_seqs[i]
         scores = batch_scores[i]
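The only change in main() is passing len(data) rather than the configured InferTaskConfig.batch_size into translate_batch, presumably so the shapes derived inside translate_batch follow the number of instances actually read rather than a fixed setting. A trivial sketch of the difference (hypothetical reader loop, not the project's data pipeline):

    batch_size = 4
    examples = list(range(10))  # 10 examples -> batches of 4, 4 and 2

    for start in range(0, len(examples), batch_size):
        data = examples[start:start + batch_size]
        # translate_batch-style code should size its buffers from len(data),
        # which is 2 for the last batch here, not from the configured batch_size.
        print(len(data))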
fluid/neural_machine_translation/transformer/model.py
@@ -283,8 +283,15 @@ def encoder(enc_input,
         encoder_layer.
     """
     for i in range(n_layer):
-        enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value,
-                                   d_model, d_inner_hid, dropout_rate)
+        enc_output = encoder_layer(
+            enc_input,
+            attn_bias,
+            n_head,
+            d_key,
+            d_value,
+            d_model,
+            d_inner_hid,
+            dropout_rate, )
         enc_input = enc_output
     return enc_output
@@ -381,9 +388,10 @@ def make_inputs(input_data_names,
                 d_model,
                 batch_size,
                 max_length,
+                is_pos,
                 slf_attn_bias_flag,
                 src_attn_bias_flag,
-                pos_flag=1):
+                enc_output_flag=False):
     """
     Define the input data layers for the transformer model.
     """
@@ -391,35 +399,43 @@ def make_inputs(input_data_names,
     # The shapes here act as placeholder.
     # The shapes set here is to pass the infer-shape in compile time.
     word = layers.data(
-        name=input_data_names[0],
+        name=input_data_names[len(input_layers)],
         shape=[batch_size * max_length, 1],
         dtype="int64",
         append_batch_size=False)
     input_layers += [word]
     # This is used for position data or label weight.
     pos = layers.data(
-        name=input_data_names[1],
+        name=input_data_names[len(input_layers)],
         shape=[batch_size * max_length, 1],
-        dtype="int64" if pos_flag else "float32",
+        dtype="int64" if is_pos else "float32",
         append_batch_size=False)
     input_layers += [pos]
     if slf_attn_bias_flag:
-        # This is used for attention bias or encoder output.
+        # This input is used to remove attention weights on paddings for the
+        # encoder and to remove attention weights on subsequent words for the
+        # decoder.
         slf_attn_bias = layers.data(
-            name=input_data_names[2]
-            if slf_attn_bias_flag == 1 else input_data_names[-1],
-            shape=[batch_size, n_head, max_length, max_length]
-            if slf_attn_bias_flag == 1 else [batch_size, max_length, d_model],
+            name=input_data_names[len(input_layers)],
+            shape=[batch_size, n_head, max_length, max_length],
             dtype="float32",
             append_batch_size=False)
         input_layers += [slf_attn_bias]
     if src_attn_bias_flag:
+        # This input is used to remove attention weights on paddings.
         src_attn_bias = layers.data(
-            name=input_data_names[3],
+            name=input_data_names[len(input_layers)],
             shape=[batch_size, n_head, max_length, max_length],
            dtype="float32",
             append_batch_size=False)
         input_layers += [src_attn_bias]
+    if enc_output_flag:
+        enc_output = layers.data(
+            name=input_data_names[len(input_layers)],
+            shape=[batch_size, max_length, d_model],
+            dtype="float32",
+            append_batch_size=False)
+        input_layers += [enc_output]
     return input_layers
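After this refactor make_inputs no longer hard-codes positions 0 to 3 of input_data_names; each created layer takes input_data_names[len(input_layers)], i.e. the next unused name, and the boolean flags decide which optional layers (self-attention bias, source-attention bias, encoder output) get appended. A framework-free sketch of that indexing pattern (strings stand in for fluid.layers.data calls, and the name tuple below is assumed for illustration):

    def make_inputs_sketch(input_data_names, is_pos, slf_attn_bias_flag,
                           src_attn_bias_flag, enc_output_flag=False):
        """Mimics the creation order; is_pos only changes the dtype in the real code."""
        input_layers = []
        # word and position inputs are always created first.
        input_layers += ["data(%s)" % input_data_names[len(input_layers)]]
        input_layers += ["data(%s)" % input_data_names[len(input_layers)]]
        if slf_attn_bias_flag:
            input_layers += ["data(%s)" % input_data_names[len(input_layers)]]
        if src_attn_bias_flag:
            input_layers += ["data(%s)" % input_data_names[len(input_layers)]]
        if enc_output_flag:
            input_layers += ["data(%s)" % input_data_names[len(input_layers)]]
        return input_layers


    names = ("trg_word", "trg_pos", "trg_slf_attn_bias", "trg_src_attn_bias",
             "enc_output")  # assumed ordering of decoder_input_data_names
    print(make_inputs_sketch(names, True, True, True, True))
    # ['data(trg_word)', 'data(trg_pos)', 'data(trg_slf_attn_bias)',
    #  'data(trg_src_attn_bias)', 'data(enc_output)']

This is what lets wrap_decoder (below) request all five decoder inputs with make_inputs(..., True, True, True, True), while the label path requests only the word and weight layers with (..., False, False, False).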
@@ -438,7 +454,7 @@ def transformer(
         trg_pad_idx,
         pos_pad_idx, ):
     enc_input_layers = make_inputs(encoder_input_data_names, n_head, d_model,
-                                   batch_size, max_length, 1, 0)
+                                   batch_size, max_length, True, True, False)
     enc_output = wrap_encoder(
         src_vocab_size,
@@ -455,7 +471,7 @@ def transformer(
         enc_input_layers, )
     dec_input_layers = make_inputs(decoder_input_data_names, n_head, d_model,
-                                   batch_size, max_length, 1, 1)
+                                   batch_size, max_length, True, True, True)
     predict = wrap_decoder(
         trg_vocab_size,
@@ -475,7 +491,7 @@ def transformer(
     # Padding index do not contribute to the total loss. The weights is used to
     # cancel padding index in calculating the loss.
     gold, weights = make_inputs(label_data_names, n_head, d_model, batch_size,
-                                max_length, 0, 0, 0)
+                                max_length, False, False, False)
     cost = layers.cross_entropy(input=predict, label=gold)
     weighted_cost = cost * weights
     return layers.reduce_sum(weighted_cost), predict
@@ -500,7 +516,7 @@ def wrap_encoder(src_vocab_size,
         # This is used to implement independent encoder program in inference.
         src_word, src_pos, src_slf_attn_bias = make_inputs(
             encoder_input_data_names, n_head, d_model, batch_size, max_length,
-            True, False)
+            True, True, False)
     else:
         src_word, src_pos, src_slf_attn_bias = enc_input_layers
     enc_input = prepare_encoder(
@@ -542,11 +558,9 @@ def wrap_decoder(trg_vocab_size,
     """
     if dec_input_layers is None:
         # This is used to implement independent decoder program in inference.
-        # No need for trg_slf_attn_bias because of no paddings in inference.
-        trg_word, trg_pos, enc_output, trg_src_attn_bias = make_inputs(
+        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = make_inputs(
             decoder_input_data_names, n_head, d_model, batch_size, max_length,
-            2, 1)
-        trg_slf_attn_bias = None
+            True, True, True, True)
     else:
         trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_input_layers