PaddlePaddle / hapi

Commit b2eb5149
Authored Feb 21, 2020 by guosheng
Parent: 2a0991e1

Add beam search for Transformer.

Showing 6 changed files with 963 additions and 513 deletions (+963, -513)
transformer/predict.py        +44  -45
transformer/reader.py          +3   -3
transformer/rnn_api.py       +778   -0
transformer/run.sh            +41   -0
transformer/transformer.py    +97 -465
transformer_pr.tar.gz          +0   -0
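This commit adds beam-search decoding for the Transformer (used through the new InferTransformer in transformer/predict.py below). As background only, and not the code added in this commit, one beam-search expansion step keeps the beam_size best continuations ranked by cumulative log-probability. A minimal numpy sketch, with all names hypothetical:

import numpy as np

def beam_search_step(cum_log_probs, step_log_probs, beam_size):
    # cum_log_probs: [beam_size] scores of the live beams.
    # step_log_probs: [beam_size, vocab_size] log-probs of the next token.
    total = cum_log_probs[:, None] + step_log_probs   # score of every continuation
    flat = total.reshape(-1)                          # [beam_size * vocab_size]
    top_ids = np.argsort(flat)[::-1][:beam_size]      # indices of the best ones
    vocab_size = step_log_probs.shape[1]
    parent_beam_ids = top_ids // vocab_size           # which beam each one extends
    token_ids = top_ids % vocab_size                  # which token it appends
    return flat[top_ids], parent_beam_ids, token_ids

# toy usage: 2 beams, 3-word vocabulary -> parents [0, 1], tokens [1, 2]
scores, parents, tokens = beam_search_step(
    np.log(np.array([0.6, 0.4])),
    np.log(np.array([[0.1, 0.7, 0.2],
                     [0.3, 0.3, 0.4]])),
    beam_size=2)

After each step the per-beam decoder states are gathered according to parent_beam_ids, and beams that emit the end token are moved to the finished set together with their scores.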
transformer/predict.py
@@ -16,7 +16,9 @@ import logging
 import os
 import six
 import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 import time
+import contextlib
 import numpy as np
 import paddle
@@ -27,10 +29,11 @@ from utils.check import check_gpu, check_version
 # include task-specific libs
 import reader
-from model import Transformer, position_encoding_init
+from transformer import InferTransformer, position_encoding_init
 
 
-def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
+def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
+                     output_eos=False):
     """
     Post-process the decoded sequence.
     """
@@ -47,10 +50,13 @@ def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
 def do_predict(args):
-    if args.use_cuda:
-        place = fluid.CUDAPlace(0)
-    else:
-        place = fluid.CPUPlace()
+    device_ids = list(range(args.num_devices))
+
+    @contextlib.contextmanager
+    def null_guard():
+        yield
+
+    guard = fluid.dygraph.guard() if args.eager_run else null_guard()
 
     # define the data generator
     processor = reader.DataProcessor(fpattern=args.predict_file,
@@ -69,68 +75,61 @@ def do_predict(args):
                                      unk_mark=args.special_token[2],
                                      max_length=args.max_length,
                                      n_head=args.n_head)
-    batch_generator = processor.data_generator(phase="predict", place=place)
+    batch_generator = processor.data_generator(phase="predict")
-    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
-        args.unk_idx = processor.get_vocab_summary()
     trg_idx2word = reader.DataProcessor.load_dict(
         dict_path=args.trg_vocab_fpath, reverse=True)
+    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
+        args.unk_idx = processor.get_vocab_summary()
 
-    with fluid.dygraph.guard(place):
+    with guard:
         # define data loader
-        test_loader = fluid.io.DataLoader.from_generator(capacity=10)
-        test_loader.set_batch_generator(batch_generator, places=place)
+        test_loader = batch_generator
 
         # define model
-        transformer = Transformer(
-            args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
-            args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
-            args.d_inner_hid, args.prepostprocess_dropout,
-            args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
-            args.postprocess_cmd, args.weight_sharing, args.bos_idx,
-            args.eos_idx)
+        transformer = InferTransformer(
+            args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
+            args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
+            args.d_inner_hid, args.prepostprocess_dropout,
+            args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
+            args.postprocess_cmd, args.weight_sharing, args.bos_idx,
+            args.eos_idx,
+            beam_size=args.beam_size,
+            max_out_len=args.max_out_len)
 
         # load the trained model
         assert args.init_from_params, (
            "Please set init_from_params to load the infer model.")
-        model_dict, _ = fluid.load_dygraph(
-            os.path.join(args.init_from_params, "transformer"))
-        # to avoid a longer length than training, reset the size of position
-        # encoding to max_length
-        model_dict["encoder.pos_encoder.weight"] = position_encoding_init(
-            args.max_length + 1, args.d_model)
-        model_dict["decoder.pos_encoder.weight"] = position_encoding_init(
-            args.max_length + 1, args.d_model)
-        transformer.load_dict(model_dict)
-        # set evaluate mode
-        transformer.eval()
+        transformer.load(os.path.join(args.init_from_params, "transformer"))
 
         f = open(args.output_file, "wb")
        for input_data in test_loader():
             (src_word, src_pos, src_slf_attn_bias, trg_word,
              trg_src_attn_bias) = input_data
-            finished_seq, finished_scores = transformer.beam_search(
-                src_word,
-                src_pos,
-                src_slf_attn_bias,
-                trg_src_attn_bias,
-                bos_id=args.bos_idx,
-                eos_id=args.eos_idx,
-                beam_size=args.beam_size,
-                max_len=args.max_out_len)
-            finished_seq = finished_seq.numpy()
-            finished_scores = finished_scores.numpy()
+            finished_seq = transformer.test(
+                inputs=(src_word, src_pos, src_slf_attn_bias, trg_word,
+                        trg_src_attn_bias),
+                device='gpu',
+                device_ids=device_ids)[0]
             finished_seq = np.transpose(finished_seq, [0, 2, 1])
             for ins in finished_seq:
                 for beam_idx, beam in enumerate(ins):
                     if beam_idx >= args.n_best: break
                     id_list = post_process_seq(beam, args.bos_idx, args.eos_idx)
                     word_list = [trg_idx2word[id] for id in id_list]
                     sequence = b" ".join(word_list) + b"\n"
                     f.write(sequence)
             break
 
 
 if __name__ == "__main__":
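For reference, the new prediction loop above transposes finished_seq and then writes the n_best hypotheses per instance. Below is a minimal, self-contained sketch of that post-processing, assuming finished_seq has shape [batch_size, beam_size, seq_len] after the transpose and that post_process_seq (whose body is not shown in this hunk) truncates at the first eos and drops bos/eos by default; all names here are illustrative only.

import numpy as np

def extract_n_best(finished_seq, trg_idx2word, bos_idx, eos_idx, n_best):
    hypotheses = []
    for ins in finished_seq:                       # [beam_size, seq_len]
        for beam_idx, beam in enumerate(ins):
            if beam_idx >= n_best:
                break
            ids = [int(i) for i in beam]
            if bos_idx in ids:                     # drop everything up to bos
                ids = ids[ids.index(bos_idx) + 1:]
            if eos_idx in ids:                     # truncate at the first eos
                ids = ids[:ids.index(eos_idx)]
            hypotheses.append(" ".join(trg_idx2word[i] for i in ids))
    return hypotheses

# toy input: 1 instance, 2 beams, 0 = <s>, 1 = <e>
seqs = np.array([[[0, 3, 4, 1], [0, 4, 3, 1]]])
vocab = {0: "<s>", 1: "<e>", 3: "hello", 4: "world"}
print(extract_n_best(seqs, vocab, bos_idx=0, eos_idx=1, n_best=2))
# ['hello world', 'world hello']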
transformer/reader.py
@@ -114,7 +114,7 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
     return data_inputs
 
 
-def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head, place):
+def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head):
     """
     Put all padded data needed by beam search decoder into a list.
     """
@@ -517,7 +517,7 @@ class DataProcessor(object):
         return __impl__
 
-    def data_generator(self, phase, place=None):
+    def data_generator(self, phase):
         # Any token included in dict can be used to pad, since the paddings' loss
         # will be masked out by weights and make no effect on parameter gradients.
         src_pad_idx = trg_pad_idx = self._eos_idx
@@ -540,7 +540,7 @@ class DataProcessor(object):
         def __for_predict__():
             for data in data_reader():
                 data_inputs = prepare_infer_input(data, src_pad_idx, bos_idx,
-                                                  n_head, place)
+                                                  n_head)
                 yield data_inputs
 
         return __for_train__ if phase == "train" else __for_predict__
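The comment in data_generator above notes that any in-vocabulary token can serve as padding because padded positions receive zero loss weight. A small numpy illustration of that point (the values are made up):

import numpy as np

token_loss = np.array([0.7, 1.2, 0.4, 2.0, 2.0])   # per-position loss
weights    = np.array([1.0, 1.0, 1.0, 0.0, 0.0])   # 0 at padded positions

masked = token_loss * weights                       # pad losses become 0
avg_loss = masked.sum() / weights.sum()             # average over real tokens only
print(avg_loss)                                     # 0.766..., pad id is irrelevant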
transformer/rnn_api.py (new file, mode 100644)

This diff is collapsed.
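transformer/rnn_api.py is new in this commit but collapsed above. Purely as an illustration of the kind of helper a beam-search decoder needs (not a claim about what the file contains): after every top-k selection, per-beam decoder state has to be reordered so that each surviving beam continues from its parent's state. A numpy sketch with hypothetical names:

import numpy as np

def gather_beams(state, parent_beam_ids):
    # state: [batch, beam, ...] per-beam tensor; parent_beam_ids: [batch, beam]
    batch_idx = np.arange(state.shape[0])[:, None]   # [batch, 1], broadcasts over beams
    return state[batch_idx, parent_beam_ids]         # reorder along the beam axis

cache = np.arange(2 * 3 * 4).reshape(2, 3, 4)        # [batch=2, beam=3, d=4]
parents = np.array([[0, 0, 2], [1, 1, 1]])           # chosen parent per beam
print(gather_beams(cache, parents).shape)            # (2, 3, 4)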
transformer/run.sh (new file, mode 100644)
python -u train.py \
    --epoch 30 \
    --src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
    --trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
    --special_token '<s>' '<e>' '<unk>' \
    --training_file wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de.tiny \
    --validation_file wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
    --batch_size 4096 \
    --print_step 1 \
    --use_cuda True \
    --random_seed 1000 \
    --save_step 10 \
    --eager_run True
    #--init_from_pretrain_model base_model_dygraph/step_100000/ \
    #--init_from_checkpoint trained_models/step_200/transformer
    #--n_head 16 \
    #--d_model 1024 \
    #--d_inner_hid 4096 \
    #--prepostprocess_dropout 0.3

exit
echo `date`

python -u predict.py \
    --src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
    --trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
    --special_token '<s>' '<e>' '<unk>' \
    --predict_file wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
    --batch_size 64 \
    --init_from_params base_model_dygraph/step_100000/ \
    --beam_size 5 \
    --max_out_len 255 \
    --output_file predict.txt \
    --eager_run True
    #--max_length 500 \
    #--n_head 16 \
    #--d_model 1024 \
    #--d_inner_hid 4096 \
    #--prepostprocess_dropout 0.3

echo `date`
\ No newline at end of file
transformer/transformer.py

This diff is collapsed.
transformer_pr.tar.gz (new file, mode 100644)

File added.