Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
3500061d
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3500061d
编写于
4月 08, 2020
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add seq2seq infer
上级
27afc286
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
155 addition
and
15 deletion
+155
-15
seq2seq/predict.py
seq2seq/predict.py
+126
-0
seq2seq/reader.py
seq2seq/reader.py
+5
-0
seq2seq/seq2seq_attn.py
seq2seq/seq2seq_attn.py
+8
-2
seq2seq/seq2seq_base.py
seq2seq/seq2seq_base.py
+8
-2
seq2seq/train.py
seq2seq/train.py
+8
-11
未找到文件。
seq2seq/predict.py
浏览文件 @
3500061d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
logging
import
os
import
io
import
sys
sys
.
path
.
append
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))))
import
random
from
functools
import
partial
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.fluid.layers.utils
import
flatten
from
paddle.fluid.io
import
DataLoader
from
model
import
Input
,
set_device
from
args
import
parse_args
from
seq2seq_base
import
BaseInferModel
from
seq2seq_attn
import
AttentionInferModel
from
reader
import
Seq2SeqDataset
,
Seq2SeqBatchSampler
,
SortType
,
prepare_infer_input
def
post_process_seq
(
seq
,
bos_idx
,
eos_idx
,
output_bos
=
False
,
output_eos
=
False
):
"""
Post-process the decoded sequence.
"""
eos_pos
=
len
(
seq
)
-
1
for
i
,
idx
in
enumerate
(
seq
):
if
idx
==
eos_idx
:
eos_pos
=
i
break
seq
=
[
idx
for
idx
in
seq
[:
eos_pos
+
1
]
if
(
output_bos
or
idx
!=
bos_idx
)
and
(
output_eos
or
idx
!=
eos_idx
)
]
return
seq
def
do_predict
(
args
):
device
=
set_device
(
"gpu"
if
args
.
use_gpu
else
"cpu"
)
fluid
.
enable_dygraph
(
device
)
if
args
.
eager_run
else
None
# define model
inputs
=
[
Input
(
[
None
,
None
],
"int64"
,
name
=
"src_word"
),
Input
(
[
None
],
"int64"
,
name
=
"src_length"
),
]
# def dataloader
dataset
=
Seq2SeqDataset
(
fpattern
=
args
.
infer_file
,
src_vocab_fpath
=
args
.
vocab_prefix
+
"."
+
args
.
src_lang
,
trg_vocab_fpath
=
args
.
vocab_prefix
+
"."
+
args
.
tar_lang
,
token_delimiter
=
None
,
start_mark
=
"<s>"
,
end_mark
=
"</s>"
,
unk_mark
=
"<unk>"
)
trg_idx2word
=
Seq2SeqDataset
.
load_dict
(
dict_path
=
args
.
vocab_prefix
+
"."
+
args
.
tar_lang
,
reverse
=
True
)
(
args
.
src_vocab_size
,
args
.
trg_vocab_size
,
bos_id
,
eos_id
,
unk_id
)
=
dataset
.
get_vocab_summary
()
batch_sampler
=
Seq2SeqBatchSampler
(
dataset
=
dataset
,
use_token_batch
=
False
,
batch_size
=
args
.
batch_size
)
data_loader
=
DataLoader
(
dataset
=
dataset
,
batch_sampler
=
batch_sampler
,
places
=
device
,
feed_list
=
None
if
fluid
.
in_dygraph_mode
()
else
[
x
.
forward
()
for
x
in
inputs
],
collate_fn
=
partial
(
prepare_infer_input
,
bos_id
=
bos_id
,
eos_id
=
eos_id
,
pad_id
=
eos_id
),
num_workers
=
0
,
return_list
=
True
)
model_maker
=
AttentionInferModel
if
args
.
attention
else
BaseInferModel
model
=
model_maker
(
args
.
src_vocab_size
,
args
.
tar_vocab_size
,
args
.
hidden_size
,
args
.
hidden_size
,
args
.
num_layers
,
args
.
dropout
,
bos_id
=
bos_id
,
eos_id
=
eos_id
,
beam_size
=
args
.
beam_size
,
max_out_len
=
256
)
model
.
prepare
(
inputs
=
inputs
)
# load the trained model
assert
args
.
reload_model
,
(
"Please set reload_model to load the infer model."
)
model
.
load
(
args
.
reload_model
)
# TODO(guosheng): use model.predict when support variant length
with
io
.
open
(
args
.
infer_output_file
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
for
data
in
data_loader
():
finished_seq
=
model
.
test
(
inputs
=
flatten
(
data
))[
0
]
finished_seq
=
np
.
transpose
(
finished_seq
,
[
0
,
2
,
1
])
for
ins
in
finished_seq
:
for
beam_idx
,
beam
in
enumerate
(
ins
):
id_list
=
post_process_seq
(
beam
,
bos_id
,
eos_id
)
word_list
=
[
trg_idx2word
[
id
]
for
id
in
id_list
]
sequence
=
" "
.
join
(
word_list
)
+
"
\n
"
f
.
write
(
sequence
)
break
if
__name__
==
"__main__"
:
args
=
parse_args
()
do_predict
(
args
)
seq2seq/reader.py
浏览文件 @
3500061d
...
...
@@ -33,6 +33,11 @@ def prepare_train_input(insts, bos_id, eos_id, pad_id):
return
src
,
src_length
,
trg
[:,
:
-
1
],
trg_length
,
trg
[:,
1
:,
np
.
newaxis
]
def
prepare_infer_input
(
insts
,
bos_id
,
eos_id
,
pad_id
):
src
,
src_length
=
pad_batch_data
(
insts
,
pad_id
=
pad_id
)
return
src
,
src_length
def
pad_batch_data
(
insts
,
pad_id
):
"""
Pad the instances to the max sequence length in batch, and generate the
...
...
seq2seq/seq2seq_attn.py
浏览文件 @
3500061d
...
...
@@ -90,7 +90,10 @@ class DecoderCell(RNNCell):
for
i
,
lstm_cell
in
enumerate
(
self
.
lstm_cells
):
out
,
new_lstm_state
=
lstm_cell
(
step_input
,
lstm_states
[
i
])
step_input
=
layers
.
dropout
(
out
,
self
.
dropout_prob
)
if
self
.
dropout_prob
>
0
else
out
out
,
self
.
dropout_prob
,
dropout_implementation
=
'upscale_in_train'
)
if
self
.
dropout_prob
>
0
else
out
new_lstm_states
.
append
(
new_lstm_state
)
out
=
self
.
attention_layer
(
step_input
,
encoder_output
,
encoder_padding_mask
)
...
...
@@ -180,7 +183,8 @@ class AttentionModel(Model):
class
AttentionInferModel
(
AttentionModel
):
def
__init__
(
self
,
vocab_size
,
src_vocab_size
,
trg_vocab_size
,
embed_dim
,
hidden_size
,
num_layers
,
...
...
@@ -192,6 +196,8 @@ class AttentionInferModel(AttentionModel):
args
=
dict
(
locals
())
args
.
pop
(
"self"
)
args
.
pop
(
"__class__"
,
None
)
# py3
self
.
bos_id
=
args
.
pop
(
"bos_id"
)
self
.
eos_id
=
args
.
pop
(
"eos_id"
)
self
.
beam_size
=
args
.
pop
(
"beam_size"
)
self
.
max_out_len
=
args
.
pop
(
"max_out_len"
)
super
(
AttentionInferModel
,
self
).
__init__
(
**
args
)
...
...
seq2seq/seq2seq_base.py
浏览文件 @
3500061d
...
...
@@ -63,7 +63,10 @@ class EncoderCell(RNNCell):
for
i
,
lstm_cell
in
enumerate
(
self
.
lstm_cells
):
out
,
new_state
=
lstm_cell
(
step_input
,
states
[
i
])
step_input
=
layers
.
dropout
(
out
,
self
.
dropout_prob
)
if
self
.
dropout_prob
>
0
else
out
out
,
self
.
dropout_prob
,
dropout_implementation
=
'upscale_in_train'
)
if
self
.
dropout_prob
>
0
else
out
new_states
.
append
(
new_state
)
return
step_input
,
new_states
...
...
@@ -163,7 +166,8 @@ class BaseModel(Model):
class
BaseInferModel
(
BaseModel
):
def
__init__
(
self
,
vocab_size
,
src_vocab_size
,
trg_vocab_size
,
embed_dim
,
hidden_size
,
num_layers
,
...
...
@@ -175,6 +179,8 @@ class BaseInferModel(BaseModel):
args
=
dict
(
locals
())
args
.
pop
(
"self"
)
args
.
pop
(
"__class__"
,
None
)
# py3
self
.
bos_id
=
args
.
pop
(
"bos_id"
)
self
.
eos_id
=
args
.
pop
(
"eos_id"
)
self
.
beam_size
=
args
.
pop
(
"beam_size"
)
self
.
max_out_len
=
args
.
pop
(
"max_out_len"
)
super
(
BaseInferModel
,
self
).
__init__
(
**
args
)
...
...
seq2seq/train.py
浏览文件 @
3500061d
...
...
@@ -14,25 +14,20 @@
import
logging
import
os
import
six
import
sys
sys
.
path
.
append
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))))
import
random
from
functools
import
partial
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph
import
to_variable
from
paddle.fluid.io
import
DataLoader
from
paddle.fluid.dygraph_grad_clip
import
GradClipByGlobalNorm
import
reader
from
model
import
Input
,
set_device
from
callbacks
import
ProgBarLogger
from
args
import
parse_args
from
seq2seq_base
import
BaseModel
,
CrossEntropyCriterion
from
seq2seq_attn
import
AttentionModel
from
model
import
Input
,
set_device
from
callbacks
import
ProgBarLogger
from
reader
import
Seq2SeqDataset
,
Seq2SeqBatchSampler
,
SortType
,
prepare_train_input
...
...
@@ -97,7 +92,8 @@ def do_train(args):
data_loaders
[
i
]
=
data_loader
train_loader
,
eval_loader
=
data_loaders
model
=
AttentionModel
(
args
.
src_vocab_size
,
args
.
tar_vocab_size
,
model_maker
=
AttentionModel
if
args
.
attention
else
BaseModel
model
=
model_maker
(
args
.
src_vocab_size
,
args
.
tar_vocab_size
,
args
.
hidden_size
,
args
.
hidden_size
,
args
.
num_layers
,
args
.
dropout
)
...
...
@@ -110,9 +106,10 @@ def do_train(args):
labels
=
labels
)
model
.
fit
(
train_data
=
train_loader
,
eval_data
=
eval_loader
,
epochs
=
1
,
epochs
=
args
.
max_epoch
,
eval_freq
=
1
,
save_freq
=
1
,
save_dir
=
args
.
model_path
,
log_freq
=
1
,
verbose
=
2
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录