Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
12fb5614
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
12fb5614
编写于
4月 21, 2020
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine PPL and reader for seq2seq.
上级
833a0157
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
138 addition
and
149 deletion
+138
-149
seq2seq/README.md
seq2seq/README.md
+1
-1
seq2seq/args.py
seq2seq/args.py
+6
-0
seq2seq/reader.py
seq2seq/reader.py
+2
-2
seq2seq/seq2seq.yaml
seq2seq/seq2seq.yaml
+0
-83
seq2seq/seq2seq_base.py
seq2seq/seq2seq_base.py
+41
-0
seq2seq/train.py
seq2seq/train.py
+8
-63
seq2seq/utility.py
seq2seq/utility.py
+80
-0
未找到文件。
seq2seq/README.md
浏览文件 @
12fb5614
...
@@ -151,7 +151,7 @@ python infer.py \
...
@@ -151,7 +151,7 @@ python infer.py \
--reload_model
attention_models/epoch_10
\
--reload_model
attention_models/epoch_10
\
--infer_output_file
attention_infer_output/infer_output.txt
\
--infer_output_file
attention_infer_output/infer_output.txt
\
--beam_size
10
\
--beam_size
10
\
--use_gpu
True
--use_gpu
True
\
--eager_run
False
--eager_run
False
```
```
...
...
seq2seq/args.py
浏览文件 @
12fb5614
...
@@ -88,6 +88,12 @@ def parse_args():
...
@@ -88,6 +88,12 @@ def parse_args():
default
=
5.0
,
default
=
5.0
,
help
=
"max grad norm for global norm clip"
)
help
=
"max grad norm for global norm clip"
)
parser
.
add_argument
(
"--log_freq"
,
type
=
int
,
default
=
100
,
help
=
"The frequency to print training logs"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--model_path"
,
"--model_path"
,
type
=
str
,
type
=
str
,
...
...
seq2seq/reader.py
浏览文件 @
12fb5614
...
@@ -168,7 +168,7 @@ class SampleInfo(object):
...
@@ -168,7 +168,7 @@ class SampleInfo(object):
def
__init__
(
self
,
i
,
lens
):
def
__init__
(
self
,
i
,
lens
):
self
.
i
=
i
self
.
i
=
i
self
.
lens
=
lens
self
.
lens
=
lens
self
.
max_len
=
lens
[
0
]
self
.
max_len
=
lens
[
0
]
# to be consitent with the original reader
def
get_ranges
(
self
,
min_length
=
None
,
max_length
=
None
,
truncate
=
False
):
def
get_ranges
(
self
,
min_length
=
None
,
max_length
=
None
,
truncate
=
False
):
ranges
=
[]
ranges
=
[]
...
@@ -379,7 +379,7 @@ class Seq2SeqBatchSampler(BatchSampler):
...
@@ -379,7 +379,7 @@ class Seq2SeqBatchSampler(BatchSampler):
reverse
=
True
reverse
=
True
for
i
in
range
(
0
,
len
(
infos
),
self
.
_pool_size
):
for
i
in
range
(
0
,
len
(
infos
),
self
.
_pool_size
):
# to avoid placing short next to long sentences
# to avoid placing short next to long sentences
reverse
=
not
reverse
reverse
=
False
#
not reverse
infos
[
i
:
i
+
self
.
_pool_size
]
=
sorted
(
infos
[
i
:
i
+
self
.
_pool_size
]
=
sorted
(
infos
[
i
:
i
+
self
.
_pool_size
],
infos
[
i
:
i
+
self
.
_pool_size
],
key
=
lambda
x
:
x
.
max_len
,
key
=
lambda
x
:
x
.
max_len
,
...
...
seq2seq/seq2seq.yaml
已删除
100644 → 0
浏览文件 @
833a0157
# used for continuous evaluation
enable_ce
:
False
eager_run
:
False
# The frequency to save trained models when training.
save_step
:
10000
# The frequency to fetch and print output when training.
print_step
:
100
# path of the checkpoint, to resume the previous training
init_from_checkpoint
:
"
"
# path of the pretrain model, to better solve the current task
init_from_pretrain_model
:
"
"
# path of trained parameter, to make prediction
init_from_params
:
"
trained_params/step_100000/"
# the directory for saving model
save_model
:
"
trained_models"
# the directory for saving inference model.
inference_model_dir
:
"
infer_model"
# Set seed for CE or debug
random_seed
:
None
# The pattern to match training data files.
training_file
:
"
wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de"
# The pattern to match validation data files.
validation_file
:
"
wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de"
# The pattern to match test data files.
predict_file
:
"
wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de"
# The file to output the translation results of predict_file to.
output_file
:
"
predict.txt"
# The path of vocabulary file of source language.
src_vocab_fpath
:
"
wmt16_ende_data_bpe/vocab_all.bpe.32000"
# The path of vocabulary file of target language.
trg_vocab_fpath
:
"
wmt16_ende_data_bpe/vocab_all.bpe.32000"
# The <bos>, <eos> and <unk> tokens in the dictionary.
special_token
:
[
"
<s>"
,
"
<e>"
,
"
<unk>"
]
# max length of sequences
max_length
:
256
# whether to use cuda
use_cuda
:
True
# args for reader, see reader.py for details
token_delimiter
:
"
"
use_token_batch
:
True
pool_size
:
200000
sort_type
:
"
pool"
shuffle
:
True
shuffle_batch
:
True
batch_size
:
4096
# Hyparams for training:
# the number of epoches for training
epoch
:
30
# the hyper parameters for Adam optimizer.
# This static learning_rate will be multiplied to the LearningRateScheduler
# derived learning rate the to get the final learning rate.
learning_rate
:
0.001
# Hyparams for generation:
# the parameters for beam search.
beam_size
:
5
max_out_len
:
256
# the number of decoded sentences to output.
n_best
:
1
# Hyparams for model:
# These following five vocabularies related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# size of source word dictionary.
src_vocab_size
:
10000
# size of target word dictionay
trg_vocab_size
:
10000
# index for <bos> token
bos_idx
:
0
# index for <eos> token
eos_idx
:
1
# index for <unk> token
unk_idx
:
2
embed_dim
:
512
hidden_size
:
512
num_layers
:
2
dropout
:
0.1
seq2seq/seq2seq_base.py
浏览文件 @
12fb5614
...
@@ -200,3 +200,44 @@ class BaseInferModel(BaseModel):
...
@@ -200,3 +200,44 @@ class BaseInferModel(BaseModel):
# dynamic decoding with beam search
# dynamic decoding with beam search
rs
,
_
=
self
.
beam_search_decoder
(
inits
=
encoder_final_states
)
rs
,
_
=
self
.
beam_search_decoder
(
inits
=
encoder_final_states
)
return
rs
return
rs
class
BaseGreedyInferModel
(
BaseModel
):
def
__init__
(
self
,
src_vocab_size
,
trg_vocab_size
,
embed_dim
,
hidden_size
,
num_layers
,
dropout_prob
=
0.
,
bos_id
=
0
,
eos_id
=
1
,
beam_size
=
1
,
max_out_len
=
256
):
args
=
dict
(
locals
())
args
.
pop
(
"self"
)
args
.
pop
(
"__class__"
,
None
)
# py3
args
.
pop
(
"beam_size"
,
None
)
self
.
bos_id
=
args
.
pop
(
"bos_id"
)
self
.
eos_id
=
args
.
pop
(
"eos_id"
)
self
.
max_out_len
=
args
.
pop
(
"max_out_len"
)
super
(
BaseGreedyInferModel
,
self
).
__init__
(
**
args
)
# dynamic decoder for inference
decoder_helper
=
GreedyEmbeddingHelper
(
start_tokens
=
bos_id
,
end_token
=
eos_id
,
embedding_fn
=
self
.
decoder
.
embedder
)
decoder
=
BasicDecoder
(
cell
=
self
.
decoder
.
stack_lstm
.
cell
,
helper
=
decoder_helper
,
output_fn
=
self
.
decoder
.
output_layer
)
self
.
greedy_search_decoder
=
DynamicDecode
(
decoder
,
max_step_num
=
max_out_len
,
is_test
=
True
)
def
forward
(
self
,
src
,
src_length
):
# encoding
encoder_output
,
encoder_final_states
=
self
.
encoder
(
src
,
src_length
)
# dynamic decoding with greedy search
rs
,
_
=
self
.
greedy_search_decoder
(
inits
=
encoder_final_states
)
return
rs
.
sample_ids
seq2seq/train.py
浏览文件 @
12fb5614
...
@@ -30,65 +30,7 @@ from args import parse_args
...
@@ -30,65 +30,7 @@ from args import parse_args
from
seq2seq_base
import
BaseModel
,
CrossEntropyCriterion
from
seq2seq_base
import
BaseModel
,
CrossEntropyCriterion
from
seq2seq_attn
import
AttentionModel
from
seq2seq_attn
import
AttentionModel
from
reader
import
create_data_loader
from
reader
import
create_data_loader
from
utility
import
PPL
,
TrainCallback
class
TrainCallback
(
ProgBarLogger
):
def
__init__
(
self
,
args
,
ppl
,
verbose
=
2
):
super
(
TrainCallback
,
self
).
__init__
(
1
,
verbose
)
# control metric
self
.
ppl
=
ppl
self
.
batch_size
=
args
.
batch_size
def
on_train_begin
(
self
,
logs
=
None
):
super
(
TrainCallback
,
self
).
on_train_begin
(
logs
)
self
.
train_metrics
+=
[
"ppl"
]
# remove loss to not print it
self
.
ppl
.
reset
()
def
on_train_batch_end
(
self
,
step
,
logs
=
None
):
batch_loss
=
logs
[
"loss"
][
0
]
self
.
ppl
.
total_loss
+=
batch_loss
*
self
.
batch_size
logs
[
"ppl"
]
=
np
.
exp
(
self
.
ppl
.
total_loss
/
self
.
ppl
.
word_count
)
if
step
>
0
and
step
%
self
.
ppl
.
reset_freq
==
0
:
self
.
ppl
.
reset
()
super
(
TrainCallback
,
self
).
on_train_batch_end
(
step
,
logs
)
def
on_eval_begin
(
self
,
logs
=
None
):
super
(
TrainCallback
,
self
).
on_eval_begin
(
logs
)
self
.
eval_metrics
=
[
"ppl"
]
self
.
ppl
.
reset
()
def
on_eval_batch_end
(
self
,
step
,
logs
=
None
):
batch_loss
=
logs
[
"loss"
][
0
]
self
.
ppl
.
total_loss
+=
batch_loss
*
self
.
batch_size
logs
[
"ppl"
]
=
np
.
exp
(
self
.
ppl
.
total_loss
/
self
.
ppl
.
word_count
)
super
(
TrainCallback
,
self
).
on_eval_batch_end
(
step
,
logs
)
class
PPL
(
Metric
):
def
__init__
(
self
,
reset_freq
=
100
,
name
=
None
):
super
(
PPL
,
self
).
__init__
()
self
.
_name
=
name
or
"ppl"
self
.
reset_freq
=
reset_freq
self
.
reset
()
def
add_metric_op
(
self
,
pred
,
label
):
seq_length
=
label
[
0
]
word_num
=
fluid
.
layers
.
reduce_sum
(
seq_length
)
return
word_num
def
update
(
self
,
word_num
):
self
.
word_count
+=
word_num
return
word_num
def
reset
(
self
):
self
.
total_loss
=
0
self
.
word_count
=
0
def
accumulate
(
self
):
return
self
.
word_count
def
name
(
self
):
return
self
.
_name
def
do_train
(
args
):
def
do_train
(
args
):
...
@@ -122,10 +64,13 @@ def do_train(args):
...
@@ -122,10 +64,13 @@ def do_train(args):
model
=
model_maker
(
args
.
src_vocab_size
,
args
.
tar_vocab_size
,
model
=
model_maker
(
args
.
src_vocab_size
,
args
.
tar_vocab_size
,
args
.
hidden_size
,
args
.
hidden_size
,
args
.
num_layers
,
args
.
hidden_size
,
args
.
hidden_size
,
args
.
num_layers
,
args
.
dropout
)
args
.
dropout
)
optimizer
=
fluid
.
optimizer
.
Adam
(
grad_clip
=
fluid
.
clip
.
GradientClipByGlobalNorm
(
learning_rate
=
args
.
learning_rate
,
parameter_list
=
model
.
parameters
())
optimizer
.
_grad_clip
=
fluid
.
clip
.
GradientClipByGlobalNorm
(
clip_norm
=
args
.
max_grad_norm
)
clip_norm
=
args
.
max_grad_norm
)
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
args
.
learning_rate
,
parameter_list
=
model
.
parameters
(),
grad_clip
=
grad_clip
)
ppl_metric
=
PPL
()
ppl_metric
=
PPL
()
model
.
prepare
(
model
.
prepare
(
optimizer
,
optimizer
,
...
@@ -139,7 +84,7 @@ def do_train(args):
...
@@ -139,7 +84,7 @@ def do_train(args):
eval_freq
=
1
,
eval_freq
=
1
,
save_freq
=
1
,
save_freq
=
1
,
save_dir
=
args
.
model_path
,
save_dir
=
args
.
model_path
,
callbacks
=
[
TrainCallback
(
args
,
ppl_metric
)])
callbacks
=
[
TrainCallback
(
ppl_metric
,
args
.
log_freq
)])
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
seq2seq/utility.py
0 → 100644
浏览文件 @
12fb5614
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
paddle.fluid
as
fluid
from
metrics
import
Metric
from
callbacks
import
ProgBarLogger
class
TrainCallback
(
ProgBarLogger
):
def
__init__
(
self
,
ppl
,
log_freq
,
verbose
=
2
):
super
(
TrainCallback
,
self
).
__init__
(
log_freq
,
verbose
)
self
.
ppl
=
ppl
def
on_train_begin
(
self
,
logs
=
None
):
super
(
TrainCallback
,
self
).
on_train_begin
(
logs
)
self
.
train_metrics
=
[
"ppl"
]
# remove loss to not print it
def
on_epoch_begin
(
self
,
epoch
=
None
,
logs
=
None
):
super
(
TrainCallback
,
self
).
on_epoch_begin
(
epoch
,
logs
)
self
.
ppl
.
reset
()
def
on_train_batch_end
(
self
,
step
,
logs
=
None
):
logs
[
"ppl"
]
=
self
.
ppl
.
cal_acc_ppl
(
logs
[
"loss"
][
0
],
logs
[
"batch_size"
])
if
step
>
0
and
step
%
self
.
ppl
.
reset_freq
==
0
:
self
.
ppl
.
reset
()
super
(
TrainCallback
,
self
).
on_train_batch_end
(
step
,
logs
)
def
on_eval_begin
(
self
,
logs
=
None
):
super
(
TrainCallback
,
self
).
on_eval_begin
(
logs
)
self
.
eval_metrics
=
[
"ppl"
]
self
.
ppl
.
reset
()
def
on_eval_batch_end
(
self
,
step
,
logs
=
None
):
logs
[
"ppl"
]
=
self
.
ppl
.
cal_acc_ppl
(
logs
[
"loss"
][
0
],
logs
[
"batch_size"
])
super
(
TrainCallback
,
self
).
on_eval_batch_end
(
step
,
logs
)
class
PPL
(
Metric
):
def
__init__
(
self
,
reset_freq
=
100
,
name
=
None
):
super
(
PPL
,
self
).
__init__
()
self
.
_name
=
name
or
"ppl"
self
.
reset_freq
=
reset_freq
self
.
reset
()
def
add_metric_op
(
self
,
pred
,
label
):
seq_length
=
label
[
0
]
word_num
=
fluid
.
layers
.
reduce_sum
(
seq_length
)
return
word_num
def
update
(
self
,
word_num
):
self
.
word_count
+=
word_num
[
0
]
return
word_num
def
reset
(
self
):
self
.
total_loss
=
0
self
.
word_count
=
0
def
accumulate
(
self
):
return
self
.
word_count
def
name
(
self
):
return
self
.
_name
def
cal_acc_ppl
(
self
,
batch_loss
,
batch_size
):
self
.
total_loss
+=
batch_loss
*
batch_size
ppl
=
np
.
exp
(
self
.
total_loss
/
self
.
word_count
)
return
ppl
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录