Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
80628bc6
M
models
项目概览
PaddlePaddle
/
models
大约 2 年 前同步成功
通知
232
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
80628bc6
编写于
8月 20, 2019
作者:
G
Guo Sheng
提交者:
GitHub
8月 20, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update transformer to be unified (#3119)
* Update transformer to be unified * To support CE when unifing Transformer
上级
e16b9cae
变更
17
展开全部
显示空白变更内容
内联
并排
Showing
17 changed file
with
2815 addition
and
1393 deletion
+2815
-1393
PaddleNLP/neural_machine_translation/transformer/.run_ce.sh
PaddleNLP/neural_machine_translation/transformer/.run_ce.sh
+10
-9
PaddleNLP/neural_machine_translation/transformer/README.md
PaddleNLP/neural_machine_translation/transformer/README.md
+209
-139
PaddleNLP/neural_machine_translation/transformer/_ce.py
PaddleNLP/neural_machine_translation/transformer/_ce.py
+4
-4
PaddleNLP/neural_machine_translation/transformer/config.py
PaddleNLP/neural_machine_translation/transformer/config.py
+0
-111
PaddleNLP/neural_machine_translation/transformer/desc.py
PaddleNLP/neural_machine_translation/transformer/desc.py
+89
-0
PaddleNLP/neural_machine_translation/transformer/infer.py
PaddleNLP/neural_machine_translation/transformer/infer.py
+0
-328
PaddleNLP/neural_machine_translation/transformer/inference_model.py
...neural_machine_translation/transformer/inference_model.py
+136
-0
PaddleNLP/neural_machine_translation/transformer/main.py
PaddleNLP/neural_machine_translation/transformer/main.py
+34
-0
PaddleNLP/neural_machine_translation/transformer/palm/__init__.py
...P/neural_machine_translation/transformer/palm/__init__.py
+0
-0
PaddleNLP/neural_machine_translation/transformer/palm/toolkit/__init__.py
..._machine_translation/transformer/palm/toolkit/__init__.py
+0
-0
PaddleNLP/neural_machine_translation/transformer/palm/toolkit/configure.py
...machine_translation/transformer/palm/toolkit/configure.py
+332
-0
PaddleNLP/neural_machine_translation/transformer/palm/toolkit/input_field.py
...chine_translation/transformer/palm/toolkit/input_field.py
+175
-0
PaddleNLP/neural_machine_translation/transformer/predict.py
PaddleNLP/neural_machine_translation/transformer/predict.py
+218
-0
PaddleNLP/neural_machine_translation/transformer/reader.py
PaddleNLP/neural_machine_translation/transformer/reader.py
+289
-78
PaddleNLP/neural_machine_translation/transformer/train.py
PaddleNLP/neural_machine_translation/transformer/train.py
+231
-724
PaddleNLP/neural_machine_translation/transformer/transformer.py
...NLP/neural_machine_translation/transformer/transformer.py
+977
-0
PaddleNLP/neural_machine_translation/transformer/transformer.yaml
...P/neural_machine_translation/transformer/transformer.yaml
+111
-0
未找到文件。
PaddleNLP/neural_machine_translation/transformer/.run_ce.sh
100755 → 100644
浏览文件 @
80628bc6
#!/bin/bash
#!/bin/bash
sed
-i
'$a\dropout_seed = 1000'
../../models/neural_machine_translation/transformer/desc.py
DATA_PATH
=
./dataset/wmt16
DATA_PATH
=
./dataset/wmt16
train
(){
train
(){
python
-u
train.py
\
python
-u
main.py
\
--do_train
True
\
--src_vocab_fpath
$DATA_PATH
/en_10000.dict
\
--src_vocab_fpath
$DATA_PATH
/en_10000.dict
\
--trg_vocab_fpath
$DATA_PATH
/de_10000.dict
\
--trg_vocab_fpath
$DATA_PATH
/de_10000.dict
\
--special_token
'<s>'
'<e>'
'<unk>'
\
--special_token
'<s>'
'<e>'
'<unk>'
\
--train_file_pattern
$DATA_PATH
/wmt16/train
\
--training_file
$DATA_PATH
/wmt16/train
\
--val_file_pattern
$DATA_PATH
/wmt16/val
\
--use_token_batch
True
\
--use_token_batch
True
\
--batch_size
2048
\
--batch_size
2048
\
--sort_type
pool
\
--sort_type
pool
\
--pool_size
10000
\
--pool_size
10000
\
--print_step
1
\
--weight_sharing
False
\
--epoch
20
\
--enable_ce
True
\
--enable_ce
True
\
--
fetch_steps
1
\
--
random_seed
1000
\
weight_sharing False
\
--save_checkpoint
""
\
pass_num 20
--save_param
""
}
}
cudaid
=
${
transformer
:
=0
}
# use 0-th card as default
cudaid
=
${
transformer
:
=0
}
# use 0-th card as default
...
...
PaddleNLP/neural_machine_translation/transformer/README.md
浏览文件 @
80628bc6
此差异已折叠。
点击以展开。
PaddleNLP/neural_machine_translation/transformer/_ce.py
浏览文件 @
80628bc6
...
@@ -8,20 +8,20 @@ from kpi import CostKpi, DurationKpi, AccKpi
...
@@ -8,20 +8,20 @@ from kpi import CostKpi, DurationKpi, AccKpi
#### NOTE kpi.py should shared in models in some way!!!!
#### NOTE kpi.py should shared in models in some way!!!!
train_cost_card1_kpi
=
CostKpi
(
'train_cost_card1'
,
0.002
,
0
,
actived
=
True
)
train_cost_card1_kpi
=
CostKpi
(
'train_cost_card1'
,
0.002
,
0
,
actived
=
True
)
test_cost_card1_kpi
=
CostKpi
(
'test_cost_card1'
,
0.008
,
0
,
actived
=
True
)
#
test_cost_card1_kpi = CostKpi('test_cost_card1', 0.008, 0, actived=True)
train_duration_card1_kpi
=
DurationKpi
(
train_duration_card1_kpi
=
DurationKpi
(
'train_duration_card1'
,
0.006
,
0
,
actived
=
True
)
'train_duration_card1'
,
0.006
,
0
,
actived
=
True
)
train_cost_card4_kpi
=
CostKpi
(
'train_cost_card4'
,
0.001
,
0
,
actived
=
True
)
train_cost_card4_kpi
=
CostKpi
(
'train_cost_card4'
,
0.001
,
0
,
actived
=
True
)
test_cost_card4_kpi
=
CostKpi
(
'test_cost_card4'
,
0.001
,
0
,
actived
=
True
)
#
test_cost_card4_kpi = CostKpi('test_cost_card4', 0.001, 0, actived=True)
train_duration_card4_kpi
=
DurationKpi
(
train_duration_card4_kpi
=
DurationKpi
(
'train_duration_card4'
,
0.02
,
0
,
actived
=
True
)
'train_duration_card4'
,
0.02
,
0
,
actived
=
True
)
tracking_kpis
=
[
tracking_kpis
=
[
train_cost_card1_kpi
,
train_cost_card1_kpi
,
test_cost_card1_kpi
,
#
test_cost_card1_kpi,
train_duration_card1_kpi
,
train_duration_card1_kpi
,
train_cost_card4_kpi
,
train_cost_card4_kpi
,
test_cost_card4_kpi
,
#
test_cost_card4_kpi,
train_duration_card4_kpi
,
train_duration_card4_kpi
,
]
]
...
...
PaddleNLP/neural_machine_translation/transformer/config.py
已删除
100644 → 0
浏览文件 @
e16b9cae
class
TrainTaskConfig
(
object
):
# support both CPU and GPU now.
use_gpu
=
True
# the epoch number to train.
pass_num
=
30
# the number of sequences contained in a mini-batch.
# deprecated, set batch_size in args.
batch_size
=
32
# the hyper parameters for Adam optimizer.
# This static learning_rate will be multiplied to the LearningRateScheduler
# derived learning rate the to get the final learning rate.
learning_rate
=
2.0
beta1
=
0.9
beta2
=
0.997
eps
=
1e-9
# the parameters for learning rate scheduling.
warmup_steps
=
8000
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps
=
0.1
# the directory for saving trained models.
model_dir
=
"trained_models"
# the directory for saving checkpoints.
ckpt_dir
=
"trained_ckpts"
# the directory for loading checkpoint.
# If provided, continue training from the checkpoint.
ckpt_path
=
None
# the parameter to initialize the learning rate scheduler.
# It should be provided if use checkpoints, since the checkpoint doesn't
# include the training step counter currently.
start_step
=
0
# the frequency to save trained models.
save_freq
=
10000
class
InferTaskConfig
(
object
):
use_gpu
=
True
# the number of examples in one run for sequence generation.
batch_size
=
10
# the parameters for beam search.
beam_size
=
5
max_out_len
=
256
# the number of decoded sentences to output.
n_best
=
1
# the flags indicating whether to output the special tokens.
output_bos
=
False
output_eos
=
False
output_unk
=
True
# the directory for loading the trained model.
model_path
=
"trained_models/pass_1.infer.model"
class
ModelHyperParams
(
object
):
# These following five vocabularies related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# size of source word dictionary.
src_vocab_size
=
10000
# size of target word dictionay
trg_vocab_size
=
10000
# index for <bos> token
bos_idx
=
0
# index for <eos> token
eos_idx
=
1
# index for <unk> token
unk_idx
=
2
# max length of sequences deciding the size of position encoding table.
max_length
=
256
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model
=
512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid
=
2048
# the dimension that keys are projected to for dot-product attention.
d_key
=
64
# the dimension that values are projected to for dot-product attention.
d_value
=
64
# number of head used in multi-head attention.
n_head
=
8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer
=
6
# dropout rates of different modules.
prepostprocess_dropout
=
0.1
attention_dropout
=
0.1
relu_dropout
=
0.1
# to process before each sub-layer
preprocess_cmd
=
"n"
# layer normalization
# to process after each sub-layer
postprocess_cmd
=
"da"
# dropout + residual connection
# random seed used in dropout for CE.
dropout_seed
=
None
# the flag indicating whether to share embedding and softmax weights.
# vocabularies in source and target should be same for weight sharing.
weight_sharing
=
True
def
merge_cfg_from_list
(
cfg_list
,
g_cfgs
):
"""
Set the above global configurations using the cfg_list.
"""
assert
len
(
cfg_list
)
%
2
==
0
for
key
,
value
in
zip
(
cfg_list
[
0
::
2
],
cfg_list
[
1
::
2
]):
for
g_cfg
in
g_cfgs
:
if
hasattr
(
g_cfg
,
key
):
try
:
value
=
eval
(
value
)
except
Exception
:
# for file path
pass
setattr
(
g_cfg
,
key
,
value
)
break
PaddleNLP/neural_machine_translation/transformer/desc.py
0 → 100644
浏览文件 @
80628bc6
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size
=
-
1
# The placeholder for squence length in compile time.
seq_len
=
256
# The placeholder for head number in compile time.
n_head
=
8
# The placeholder for model dim in compile time.
d_model
=
512
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs
=
{
# The actual data shape of src_word is:
# [batch_size, max_src_len_in_batch, 1]
"src_word"
:
[(
batch_size
,
seq_len
,
1
),
"int64"
,
2
],
# The actual data shape of src_pos is:
# [batch_size, max_src_len_in_batch, 1]
"src_pos"
:
[(
batch_size
,
seq_len
,
1
),
"int64"
],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias"
:
[(
batch_size
,
n_head
,
seq_len
,
seq_len
),
"float32"
],
# The actual data shape of trg_word is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_word"
:
[(
batch_size
,
seq_len
,
1
),
"int64"
,
2
],
# lod_level is only used in fast decoder.
# The actual data shape of trg_pos is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_pos"
:
[(
batch_size
,
seq_len
,
1
),
"int64"
],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
"trg_slf_attn_bias"
:
[(
batch_size
,
n_head
,
seq_len
,
seq_len
),
"float32"
],
# This input is used to remove attention weights on paddings of the source
# input in the encoder-decoder attention.
# The actual data shape of trg_src_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias"
:
[(
batch_size
,
n_head
,
seq_len
,
seq_len
),
"float32"
],
# This input is used in independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
"enc_output"
:
[(
batch_size
,
seq_len
,
d_model
),
"float32"
],
# The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_word"
:
[(
batch_size
*
seq_len
,
1
),
"int64"
],
# This input is used to mask out the loss of paddding tokens.
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight"
:
[(
batch_size
*
seq_len
,
1
),
"float32"
],
# This input is used in beam-search decoder.
"init_score"
:
[(
batch_size
,
1
),
"float32"
,
2
],
# This input is used in beam-search decoder for the first gather
# (cell states updation)
"init_idx"
:
[(
batch_size
,
),
"int32"
],
}
# Names of word embedding table which might be reused for weight sharing.
word_emb_param_names
=
(
"src_word_emb_table"
,
"trg_word_emb_table"
,
)
# Names of position encoding table which will be initialized externally.
pos_enc_param_names
=
(
"src_pos_enc_table"
,
"trg_pos_enc_table"
,
)
# separated inputs for different usages.
encoder_data_input_fields
=
(
"src_word"
,
"src_pos"
,
"src_slf_attn_bias"
,
)
decoder_data_input_fields
=
(
"trg_word"
,
"trg_pos"
,
"trg_slf_attn_bias"
,
"trg_src_attn_bias"
,
"enc_output"
,
)
label_data_input_fields
=
(
"lbl_word"
,
"lbl_weight"
,
)
# In fast decoder, trg_pos (only containing the current time step) is generated
# by ops and trg_slf_attn_bias is not needed.
fast_decoder_data_input_fields
=
(
"trg_word"
,
"init_score"
,
"init_idx"
,
"trg_src_attn_bias"
,
)
PaddleNLP/neural_machine_translation/transformer/infer.py
已删除
100644 → 0
浏览文件 @
e16b9cae
import
argparse
import
ast
import
multiprocessing
import
numpy
as
np
import
os
import
sys
sys
.
path
.
append
(
"../../"
)
sys
.
path
.
append
(
"../../models/neural_machine_translation/transformer/"
)
from
functools
import
partial
import
paddle
import
paddle.fluid
as
fluid
from
models.model_check
import
check_cuda
import
reader
from
config
import
*
from
desc
import
*
from
model
import
fast_decode
as
fast_decoder
from
train
import
pad_batch_data
,
prepare_data_generator
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
"Training for Transformer."
)
parser
.
add_argument
(
"--src_vocab_fpath"
,
type
=
str
,
required
=
True
,
help
=
"The path of vocabulary file of source language."
)
parser
.
add_argument
(
"--trg_vocab_fpath"
,
type
=
str
,
required
=
True
,
help
=
"The path of vocabulary file of target language."
)
parser
.
add_argument
(
"--test_file_pattern"
,
type
=
str
,
required
=
True
,
help
=
"The pattern to match test data files."
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
50
,
help
=
"The number of examples in one run for sequence generation."
)
parser
.
add_argument
(
"--pool_size"
,
type
=
int
,
default
=
10000
,
help
=
"The buffer size to pool data."
)
parser
.
add_argument
(
"--special_token"
,
type
=
str
,
default
=
[
"<s>"
,
"<e>"
,
"<unk>"
],
nargs
=
3
,
help
=
"The <bos>, <eos> and <unk> tokens in the dictionary."
)
parser
.
add_argument
(
"--token_delimiter"
,
type
=
lambda
x
:
str
(
x
.
encode
().
decode
(
"unicode-escape"
)),
default
=
" "
,
help
=
"The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. "
)
parser
.
add_argument
(
"--use_mem_opt"
,
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
"The flag indicating whether to use memory optimization."
)
parser
.
add_argument
(
"--use_py_reader"
,
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
"The flag indicating whether to use py_reader."
)
parser
.
add_argument
(
"--use_parallel_exe"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"The flag indicating whether to use ParallelExecutor."
)
parser
.
add_argument
(
'opts'
,
help
=
'See config.py for all options'
,
default
=
None
,
nargs
=
argparse
.
REMAINDER
)
args
=
parser
.
parse_args
()
# Append args related to dict
src_dict
=
reader
.
DataReader
.
load_dict
(
args
.
src_vocab_fpath
)
trg_dict
=
reader
.
DataReader
.
load_dict
(
args
.
trg_vocab_fpath
)
dict_args
=
[
"src_vocab_size"
,
str
(
len
(
src_dict
)),
"trg_vocab_size"
,
str
(
len
(
trg_dict
)),
"bos_idx"
,
str
(
src_dict
[
args
.
special_token
[
0
]]),
"eos_idx"
,
str
(
src_dict
[
args
.
special_token
[
1
]]),
"unk_idx"
,
str
(
src_dict
[
args
.
special_token
[
2
]])
]
merge_cfg_from_list
(
args
.
opts
+
dict_args
,
[
InferTaskConfig
,
ModelHyperParams
])
return
args
def
post_process_seq
(
seq
,
bos_idx
=
ModelHyperParams
.
bos_idx
,
eos_idx
=
ModelHyperParams
.
eos_idx
,
output_bos
=
InferTaskConfig
.
output_bos
,
output_eos
=
InferTaskConfig
.
output_eos
):
"""
Post-process the beam-search decoded sequence. Truncate from the first
<eos> and remove the <bos> and <eos> tokens currently.
"""
eos_pos
=
len
(
seq
)
-
1
for
i
,
idx
in
enumerate
(
seq
):
if
idx
==
eos_idx
:
eos_pos
=
i
break
seq
=
[
idx
for
idx
in
seq
[:
eos_pos
+
1
]
if
(
output_bos
or
idx
!=
bos_idx
)
and
(
output_eos
or
idx
!=
eos_idx
)
]
return
seq
def
prepare_batch_input
(
insts
,
data_input_names
,
src_pad_idx
,
bos_idx
,
n_head
,
d_model
,
place
):
"""
Put all padded data needed by beam search decoder into a dict.
"""
src_word
,
src_pos
,
src_slf_attn_bias
,
src_max_len
=
pad_batch_data
(
[
inst
[
0
]
for
inst
in
insts
],
src_pad_idx
,
n_head
,
is_target
=
False
)
# start tokens
trg_word
=
np
.
asarray
([[
bos_idx
]]
*
len
(
insts
),
dtype
=
"int64"
)
trg_src_attn_bias
=
np
.
tile
(
src_slf_attn_bias
[:,
:,
::
src_max_len
,
:],
[
1
,
1
,
1
,
1
]).
astype
(
"float32"
)
trg_word
=
trg_word
.
reshape
(
-
1
,
1
,
1
)
src_word
=
src_word
.
reshape
(
-
1
,
src_max_len
,
1
)
src_pos
=
src_pos
.
reshape
(
-
1
,
src_max_len
,
1
)
def
to_lodtensor
(
data
,
place
,
lod
=
None
):
data_tensor
=
fluid
.
LoDTensor
()
data_tensor
.
set
(
data
,
place
)
if
lod
is
not
None
:
data_tensor
.
set_lod
(
lod
)
return
data_tensor
# beamsearch_op must use tensors with lod
init_score
=
to_lodtensor
(
np
.
zeros_like
(
trg_word
,
dtype
=
"float32"
).
reshape
(
-
1
,
1
),
place
,
[
range
(
trg_word
.
shape
[
0
]
+
1
)]
*
2
)
trg_word
=
to_lodtensor
(
trg_word
,
place
,
[
range
(
trg_word
.
shape
[
0
]
+
1
)]
*
2
)
init_idx
=
np
.
asarray
(
range
(
len
(
insts
)),
dtype
=
"int32"
)
data_input_dict
=
dict
(
zip
(
data_input_names
,
[
src_word
,
src_pos
,
src_slf_attn_bias
,
trg_word
,
init_score
,
init_idx
,
trg_src_attn_bias
]))
return
data_input_dict
def
prepare_feed_dict_list
(
data_generator
,
count
,
place
):
"""
Prepare the list of feed dict for multi-devices.
"""
feed_dict_list
=
[]
if
data_generator
is
not
None
:
# use_py_reader == False
data_input_names
=
encoder_data_input_fields
+
fast_decoder_data_input_fields
data
=
next
(
data_generator
)
for
idx
,
data_buffer
in
enumerate
(
data
):
data_input_dict
=
prepare_batch_input
(
data_buffer
,
data_input_names
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
bos_idx
,
ModelHyperParams
.
n_head
,
ModelHyperParams
.
d_model
,
place
)
feed_dict_list
.
append
(
data_input_dict
)
return
feed_dict_list
if
len
(
feed_dict_list
)
==
count
else
None
def
py_reader_provider_wrapper
(
data_reader
,
place
):
"""
Data provider needed by fluid.layers.py_reader.
"""
def
py_reader_provider
():
data_input_names
=
encoder_data_input_fields
+
fast_decoder_data_input_fields
for
batch_id
,
data
in
enumerate
(
data_reader
()):
data_input_dict
=
prepare_batch_input
(
data
,
data_input_names
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
bos_idx
,
ModelHyperParams
.
n_head
,
ModelHyperParams
.
d_model
,
place
)
yield
[
data_input_dict
[
item
]
for
item
in
data_input_names
]
return
py_reader_provider
def
fast_infer
(
args
):
"""
Inference by beam search decoder based solely on Fluid operators.
"""
out_ids
,
out_scores
,
pyreader
=
fast_decoder
(
ModelHyperParams
.
src_vocab_size
,
ModelHyperParams
.
trg_vocab_size
,
ModelHyperParams
.
max_length
+
1
,
ModelHyperParams
.
n_layer
,
ModelHyperParams
.
n_head
,
ModelHyperParams
.
d_key
,
ModelHyperParams
.
d_value
,
ModelHyperParams
.
d_model
,
ModelHyperParams
.
d_inner_hid
,
ModelHyperParams
.
prepostprocess_dropout
,
ModelHyperParams
.
attention_dropout
,
ModelHyperParams
.
relu_dropout
,
ModelHyperParams
.
preprocess_cmd
,
ModelHyperParams
.
postprocess_cmd
,
ModelHyperParams
.
weight_sharing
,
InferTaskConfig
.
beam_size
,
InferTaskConfig
.
max_out_len
,
ModelHyperParams
.
bos_idx
,
ModelHyperParams
.
eos_idx
,
use_py_reader
=
args
.
use_py_reader
)
# This is used here to set dropout to the test mode.
infer_program
=
fluid
.
default_main_program
().
clone
(
for_test
=
True
)
if
args
.
use_mem_opt
:
fluid
.
memory_optimize
(
infer_program
)
if
InferTaskConfig
.
use_gpu
:
check_cuda
(
InferTaskConfig
.
use_gpu
)
place
=
fluid
.
CUDAPlace
(
0
)
dev_count
=
fluid
.
core
.
get_cuda_device_count
()
else
:
place
=
fluid
.
CPUPlace
()
dev_count
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
multiprocessing
.
cpu_count
()))
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
fluid
.
io
.
load_vars
(
exe
,
InferTaskConfig
.
model_path
,
vars
=
[
var
for
var
in
infer_program
.
list_vars
()
if
isinstance
(
var
,
fluid
.
framework
.
Parameter
)
])
exec_strategy
=
fluid
.
ExecutionStrategy
()
# For faster executor
exec_strategy
.
use_experimental_executor
=
True
exec_strategy
.
num_threads
=
1
build_strategy
=
fluid
.
BuildStrategy
()
infer_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
TrainTaskConfig
.
use_gpu
,
main_program
=
infer_program
,
build_strategy
=
build_strategy
,
exec_strategy
=
exec_strategy
)
# data reader settings for inference
args
.
train_file_pattern
=
args
.
test_file_pattern
args
.
use_token_batch
=
False
args
.
sort_type
=
reader
.
SortType
.
NONE
args
.
shuffle
=
False
args
.
shuffle_batch
=
False
test_data
=
prepare_data_generator
(
args
,
is_test
=
False
,
count
=
dev_count
,
pyreader
=
pyreader
,
py_reader_provider_wrapper
=
py_reader_provider_wrapper
,
place
=
place
)
if
args
.
use_py_reader
:
pyreader
.
start
()
data_generator
=
None
else
:
data_generator
=
test_data
()
trg_idx2word
=
reader
.
DataReader
.
load_dict
(
dict_path
=
args
.
trg_vocab_fpath
,
reverse
=
True
)
while
True
:
try
:
feed_dict_list
=
prepare_feed_dict_list
(
data_generator
,
dev_count
,
place
)
if
args
.
use_parallel_exe
:
seq_ids
,
seq_scores
=
infer_exe
.
run
(
fetch_list
=
[
out_ids
.
name
,
out_scores
.
name
],
feed
=
feed_dict_list
,
return_numpy
=
False
)
else
:
seq_ids
,
seq_scores
=
exe
.
run
(
program
=
infer_program
,
fetch_list
=
[
out_ids
.
name
,
out_scores
.
name
],
feed
=
feed_dict_list
[
0
]
if
feed_dict_list
is
not
None
else
None
,
return_numpy
=
False
,
use_program_cache
=
False
)
seq_ids_list
,
seq_scores_list
=
[
seq_ids
],
[
seq_scores
]
if
isinstance
(
seq_ids
,
paddle
.
fluid
.
LoDTensor
)
else
(
seq_ids
,
seq_scores
)
for
seq_ids
,
seq_scores
in
zip
(
seq_ids_list
,
seq_scores_list
):
# How to parse the results:
# Suppose the lod of seq_ids is:
# [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
# then from lod[0]:
# there are 2 source sentences, beam width is 3.
# from lod[1]:
# the first source sentence has 3 hyps; the lengths are 12, 12, 16
# the second source sentence has 3 hyps; the lengths are 14, 13, 15
hyps
=
[[]
for
i
in
range
(
len
(
seq_ids
.
lod
()[
0
])
-
1
)]
scores
=
[[]
for
i
in
range
(
len
(
seq_scores
.
lod
()[
0
])
-
1
)]
for
i
in
range
(
len
(
seq_ids
.
lod
()[
0
])
-
1
):
# for each source sentence
start
=
seq_ids
.
lod
()[
0
][
i
]
end
=
seq_ids
.
lod
()[
0
][
i
+
1
]
for
j
in
range
(
end
-
start
):
# for each candidate
sub_start
=
seq_ids
.
lod
()[
1
][
start
+
j
]
sub_end
=
seq_ids
.
lod
()[
1
][
start
+
j
+
1
]
hyps
[
i
].
append
(
" "
.
join
([
trg_idx2word
[
idx
]
for
idx
in
post_process_seq
(
np
.
array
(
seq_ids
)[
sub_start
:
sub_end
])
]))
scores
[
i
].
append
(
np
.
array
(
seq_scores
)[
sub_end
-
1
])
print
(
hyps
[
i
][
-
1
])
if
len
(
hyps
[
i
])
>=
InferTaskConfig
.
n_best
:
break
except
(
StopIteration
,
fluid
.
core
.
EOFException
):
# The data pass is over.
if
args
.
use_py_reader
:
pyreader
.
reset
()
break
if
__name__
==
"__main__"
:
args
=
parse_args
()
fast_infer
(
args
)
PaddleNLP/neural_machine_translation/transformer/inference_model.py
0 → 100644
浏览文件 @
80628bc6
#encoding=utf8
import
logging
import
os
import
six
import
sys
import
time
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
#include palm for easier nlp coding
from
palm.toolkit.input_field
import
InputField
from
palm.toolkit.configure
import
PDConfig
# include task-specific libs
import
desc
import
reader
from
transformer
import
create_net
def
init_from_pretrain_model
(
args
,
exe
,
program
):
assert
isinstance
(
args
.
init_from_pretrain_model
,
str
)
if
not
os
.
path
.
exists
(
args
.
init_from_pretrain_model
):
raise
Warning
(
"The pretrained params do not exist."
)
return
False
def
existed_params
(
var
):
if
not
isinstance
(
var
,
fluid
.
framework
.
Parameter
):
return
False
return
os
.
path
.
exists
(
os
.
path
.
join
(
args
.
init_from_pretrain_model
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
args
.
init_from_pretrain_model
,
main_program
=
program
,
predicate
=
existed_params
)
print
(
"finish initing model from pretrained params from %s"
%
(
args
.
init_from_pretrain_model
))
return
True
def
init_from_params
(
args
,
exe
,
program
):
assert
isinstance
(
args
.
init_from_params
,
str
)
if
not
os
.
path
.
exists
(
args
.
init_from_params
):
raise
Warning
(
"the params path does not exist."
)
return
False
fluid
.
io
.
load_params
(
executor
=
exe
,
dirname
=
args
.
init_from_params
,
main_program
=
program
,
filename
=
"params.pdparams"
)
print
(
"finish init model from params from %s"
%
(
args
.
init_from_params
))
return
True
def
do_save_inference_model
(
args
):
if
args
.
use_cuda
:
dev_count
=
fluid
.
core
.
get_cuda_device_count
()
place
=
fluid
.
CUDAPlace
(
0
)
else
:
dev_count
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
1
))
place
=
fluid
.
CPUPlace
()
test_prog
=
fluid
.
default_main_program
()
startup_prog
=
fluid
.
default_startup_program
()
with
fluid
.
program_guard
(
test_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
# define input and reader
input_field_names
=
desc
.
encoder_data_input_fields
+
desc
.
fast_decoder_data_input_fields
input_slots
=
[{
"name"
:
name
,
"shape"
:
desc
.
input_descs
[
name
][
0
],
"dtype"
:
desc
.
input_descs
[
name
][
1
]
}
for
name
in
input_field_names
]
input_field
=
InputField
(
input_slots
)
input_field
.
build
(
build_pyreader
=
True
)
# define the network
predictions
=
create_net
(
is_training
=
False
,
model_input
=
input_field
,
args
=
args
)
out_ids
,
out_scores
=
predictions
# This is used here to set dropout to the test mode.
test_prog
=
test_prog
.
clone
(
for_test
=
True
)
# prepare predicting
## define the executor and program for training
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_prog
)
assert
(
args
.
init_from_params
)
or
(
args
.
init_from_pretrain_model
)
if
args
.
init_from_params
:
init_from_params
(
args
,
exe
,
test_prog
)
elif
args
.
init_from_pretrain_model
:
init_from_pretrain_model
(
args
,
exe
,
test_prog
)
# saving inference model
fluid
.
io
.
save_inference_model
(
args
.
inference_model_dir
,
feeded_var_names
=
input_field_names
,
target_vars
=
[
out_ids
,
out_scores
],
executor
=
exe
,
main_program
=
test_prog
,
model_filename
=
"model.pdmodel"
,
params_filename
=
"params.pdparams"
)
print
(
"save inference model at %s"
%
(
args
.
inference_model_dir
))
if
__name__
==
"__main__"
:
args
=
PDConfig
(
yaml_file
=
"./transformer.yaml"
)
args
.
build
()
args
.
Print
()
do_save_inference_model
(
args
)
PaddleNLP/neural_machine_translation/transformer/main.py
0 → 100644
浏览文件 @
80628bc6
#encoding=utf8
import
os
import
sys
import
logging
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
#include palm for easier nlp coding
from
palm.toolkit.configure
import
PDConfig
from
train
import
do_train
from
predict
import
do_predict
from
inference_model
import
do_save_inference_model
if
__name__
==
"__main__"
:
LOG_FORMAT
=
"[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
logging
.
basicConfig
(
stream
=
sys
.
stdout
,
level
=
logging
.
DEBUG
,
format
=
LOG_FORMAT
)
logging
.
getLogger
().
setLevel
(
logging
.
INFO
)
args
=
PDConfig
(
yaml_file
=
"./transformer.yaml"
)
args
.
build
()
args
.
Print
()
if
args
.
do_train
:
do_train
(
args
)
if
args
.
do_predict
:
do_predict
(
args
)
if
args
.
do_save_inference_model
:
do_save_inference_model
(
args
)
\ No newline at end of file
PaddleNLP/neural_machine_translation/transformer/palm/__init__.py
0 → 100644
浏览文件 @
80628bc6
PaddleNLP/neural_machine_translation/transformer/palm/toolkit/__init__.py
0 → 100644
浏览文件 @
80628bc6
PaddleNLP/neural_machine_translation/transformer/palm/toolkit/configure.py
0 → 100644
浏览文件 @
80628bc6
#encoding=utf8
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
sys
import
argparse
import
json
import
yaml
import
six
import
logging
logging_only_message
=
"%(message)s"
logging_details
=
"%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s"
class
JsonConfig
(
object
):
"""
A high-level api for handling json configure file.
"""
def
__init__
(
self
,
config_path
):
self
.
_config_dict
=
self
.
_parse
(
config_path
)
def
_parse
(
self
,
config_path
):
try
:
with
open
(
config_path
)
as
json_file
:
config_dict
=
json
.
load
(
json_file
)
except
:
raise
IOError
(
"Error in parsing bert model config file '%s'"
%
config_path
)
else
:
return
config_dict
def
__getitem__
(
self
,
key
):
return
self
.
_config_dict
[
key
]
def
print_config
(
self
):
for
arg
,
value
in
sorted
(
six
.
iteritems
(
self
.
_config_dict
)):
print
(
'%s: %s'
%
(
arg
,
value
))
print
(
'------------------------------------------------'
)
class
ArgumentGroup
(
object
):
def
__init__
(
self
,
parser
,
title
,
des
):
self
.
_group
=
parser
.
add_argument_group
(
title
=
title
,
description
=
des
)
def
add_arg
(
self
,
name
,
type
,
default
,
help
,
**
kwargs
):
type
=
str2bool
if
type
==
bool
else
type
self
.
_group
.
add_argument
(
"--"
+
name
,
default
=
default
,
type
=
type
,
help
=
help
+
' Default: %(default)s.'
,
**
kwargs
)
class
ArgConfig
(
object
):
"""
A high-level api for handling argument configs.
"""
def
__init__
(
self
):
parser
=
argparse
.
ArgumentParser
()
train_g
=
ArgumentGroup
(
parser
,
"training"
,
"training options."
)
train_g
.
add_arg
(
"epoch"
,
int
,
3
,
"Number of epoches for fine-tuning."
)
train_g
.
add_arg
(
"learning_rate"
,
float
,
5e-5
,
"Learning rate used to train with warmup."
)
train_g
.
add_arg
(
"lr_scheduler"
,
str
,
"linear_warmup_decay"
,
"scheduler of learning rate."
,
choices
=
[
'linear_warmup_decay'
,
'noam_decay'
])
train_g
.
add_arg
(
"weight_decay"
,
float
,
0.01
,
"Weight decay rate for L2 regularizer."
)
train_g
.
add_arg
(
"warmup_proportion"
,
float
,
0.1
,
"Proportion of training steps to perform linear learning rate warmup for."
)
train_g
.
add_arg
(
"save_steps"
,
int
,
1000
,
"The steps interval to save checkpoints."
)
train_g
.
add_arg
(
"use_fp16"
,
bool
,
False
,
"Whether to use fp16 mixed precision training."
)
train_g
.
add_arg
(
"loss_scaling"
,
float
,
1.0
,
"Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled."
)
train_g
.
add_arg
(
"pred_dir"
,
str
,
None
,
"Path to save the prediction results"
)
log_g
=
ArgumentGroup
(
parser
,
"logging"
,
"logging related."
)
log_g
.
add_arg
(
"skip_steps"
,
int
,
10
,
"The steps interval to print loss."
)
log_g
.
add_arg
(
"verbose"
,
bool
,
False
,
"Whether to output verbose log."
)
run_type_g
=
ArgumentGroup
(
parser
,
"run_type"
,
"running type options."
)
run_type_g
.
add_arg
(
"use_cuda"
,
bool
,
True
,
"If set, use GPU for training."
)
run_type_g
.
add_arg
(
"use_fast_executor"
,
bool
,
False
,
"If set, use fast parallel executor (in experiment)."
)
run_type_g
.
add_arg
(
"num_iteration_per_drop_scope"
,
int
,
1
,
"Ihe iteration intervals to clean up temporary variables."
)
run_type_g
.
add_arg
(
"do_train"
,
bool
,
True
,
"Whether to perform training."
)
run_type_g
.
add_arg
(
"do_predict"
,
bool
,
True
,
"Whether to perform prediction."
)
custom_g
=
ArgumentGroup
(
parser
,
"customize"
,
"customized options."
)
self
.
custom_g
=
custom_g
self
.
parser
=
parser
def
add_arg
(
self
,
name
,
dtype
,
default
,
descrip
):
self
.
custom_g
.
add_arg
(
name
,
dtype
,
default
,
descrip
)
def
build_conf
(
self
):
return
self
.
parser
.
parse_args
()
def
str2bool
(
v
):
# because argparse does not support to parse "true, False" as python
# boolean directly
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
def
print_arguments
(
args
,
log
=
None
):
if
not
log
:
print
(
'----------- Configuration Arguments -----------'
)
for
arg
,
value
in
sorted
(
six
.
iteritems
(
vars
(
args
))):
print
(
'%s: %s'
%
(
arg
,
value
))
print
(
'------------------------------------------------'
)
else
:
log
.
info
(
'----------- Configuration Arguments -----------'
)
for
arg
,
value
in
sorted
(
six
.
iteritems
(
vars
(
args
))):
log
.
info
(
'%s: %s'
%
(
arg
,
value
))
log
.
info
(
'------------------------------------------------'
)
class
PDConfig
(
object
):
"""
A high-level API for managing configuration files in PaddlePaddle.
Can jointly work with command-line-arugment, json files and yaml files.
"""
def
__init__
(
self
,
json_file
=
""
,
yaml_file
=
""
,
fuse_args
=
True
):
"""
Init funciton for PDConfig.
json_file: the path to the json configure file.
yaml_file: the path to the yaml configure file.
fuse_args: if fuse the json/yaml configs with argparse.
"""
assert
isinstance
(
json_file
,
str
)
assert
isinstance
(
yaml_file
,
str
)
if
json_file
!=
""
and
yaml_file
!=
""
:
raise
Warning
(
"json_file and yaml_file can not co-exist for now. please only use one configure file type."
)
return
self
.
args
=
None
self
.
arg_config
=
{}
self
.
json_config
=
{}
self
.
yaml_config
=
{}
parser
=
argparse
.
ArgumentParser
()
self
.
default_g
=
ArgumentGroup
(
parser
,
"default"
,
"default options."
)
self
.
yaml_g
=
ArgumentGroup
(
parser
,
"yaml"
,
"options from yaml."
)
self
.
json_g
=
ArgumentGroup
(
parser
,
"json"
,
"options from json."
)
self
.
com_g
=
ArgumentGroup
(
parser
,
"custom"
,
"customized options."
)
self
.
default_g
.
add_arg
(
"do_train"
,
bool
,
False
,
"Whether to perform training."
)
self
.
default_g
.
add_arg
(
"do_predict"
,
bool
,
False
,
"Whether to perform predicting."
)
self
.
default_g
.
add_arg
(
"do_eval"
,
bool
,
False
,
"Whether to perform evaluating."
)
self
.
default_g
.
add_arg
(
"do_save_inference_model"
,
bool
,
False
,
"Whether to perform model saving for inference."
)
self
.
parser
=
parser
if
json_file
!=
""
:
self
.
load_json
(
json_file
,
fuse_args
=
fuse_args
)
if
yaml_file
:
self
.
load_yaml
(
yaml_file
,
fuse_args
=
fuse_args
)
def
load_json
(
self
,
file_path
,
fuse_args
=
True
):
if
not
os
.
path
.
exists
(
file_path
):
raise
Warning
(
"the json file %s does not exist."
%
file_path
)
return
with
open
(
file_path
,
"r"
)
as
fin
:
self
.
json_config
=
json
.
loads
(
fin
.
read
())
fin
.
close
()
if
fuse_args
:
for
name
in
self
.
json_config
:
if
isinstance
(
self
.
json_config
[
name
],
list
):
self
.
json_g
.
add_arg
(
name
,
type
(
self
.
json_config
[
name
][
0
]),
self
.
json_config
[
name
],
"This is from %s"
%
file_path
,
nargs
=
len
(
self
.
json_config
[
name
]))
continue
if
not
isinstance
(
self
.
json_config
[
name
],
int
)
\
and
not
isinstance
(
self
.
json_config
[
name
],
float
)
\
and
not
isinstance
(
self
.
json_config
[
name
],
str
)
\
and
not
isinstance
(
self
.
json_config
[
name
],
bool
):
continue
self
.
json_g
.
add_arg
(
name
,
type
(
self
.
json_config
[
name
]),
self
.
json_config
[
name
],
"This is from %s"
%
file_path
)
def
load_yaml
(
self
,
file_path
,
fuse_args
=
True
):
if
not
os
.
path
.
exists
(
file_path
):
raise
Warning
(
"the yaml file %s does not exist."
%
file_path
)
return
with
open
(
file_path
,
"r"
)
as
fin
:
self
.
yaml_config
=
yaml
.
load
(
fin
,
Loader
=
yaml
.
SafeLoader
)
fin
.
close
()
if
fuse_args
:
for
name
in
self
.
yaml_config
:
if
isinstance
(
self
.
yaml_config
[
name
],
list
):
self
.
yaml_g
.
add_arg
(
name
,
type
(
self
.
yaml_config
[
name
][
0
]),
self
.
yaml_config
[
name
],
"This is from %s"
%
file_path
,
nargs
=
len
(
self
.
yaml_config
[
name
]))
continue
if
not
isinstance
(
self
.
yaml_config
[
name
],
int
)
\
and
not
isinstance
(
self
.
yaml_config
[
name
],
float
)
\
and
not
isinstance
(
self
.
yaml_config
[
name
],
str
)
\
and
not
isinstance
(
self
.
yaml_config
[
name
],
bool
):
continue
self
.
yaml_g
.
add_arg
(
name
,
type
(
self
.
yaml_config
[
name
]),
self
.
yaml_config
[
name
],
"This is from %s"
%
file_path
)
def
build
(
self
):
self
.
args
=
self
.
parser
.
parse_args
()
self
.
arg_config
=
vars
(
self
.
args
)
def
__add__
(
self
,
new_arg
):
assert
isinstance
(
new_arg
,
list
)
or
isinstance
(
new_arg
,
tuple
)
assert
len
(
new_arg
)
>=
3
assert
self
.
args
is
None
name
=
new_arg
[
0
]
dtype
=
new_arg
[
1
]
dvalue
=
new_arg
[
2
]
desc
=
new_arg
[
3
]
if
len
(
new_arg
)
==
4
else
"Description is not provided."
self
.
com_g
.
add_arg
(
name
,
dtype
,
dvalue
,
desc
)
return
self
def
__getattr__
(
self
,
name
):
if
name
in
self
.
arg_config
:
return
self
.
arg_config
[
name
]
if
name
in
self
.
json_config
:
return
self
.
json_config
[
name
]
if
name
in
self
.
yaml_config
:
return
self
.
yaml_config
[
name
]
raise
Warning
(
"The argument %s is not defined."
%
name
)
def
Print
(
self
):
print
(
"-"
*
70
)
for
name
in
self
.
arg_config
:
print
(
"%s:
\t\t\t\t
%s"
%
(
str
(
name
),
str
(
self
.
arg_config
[
name
])))
for
name
in
self
.
json_config
:
if
name
not
in
self
.
arg_config
:
print
(
"%s:
\t\t\t\t
%s"
%
(
str
(
name
),
str
(
self
.
json_config
[
name
])))
for
name
in
self
.
yaml_config
:
if
name
not
in
self
.
arg_config
:
print
(
"%s:
\t\t\t\t
%s"
%
(
str
(
name
),
str
(
self
.
yaml_config
[
name
])))
print
(
"-"
*
70
)
if
__name__
==
"__main__"
:
"""
pd_config = PDConfig(json_file = "./test/bert_config.json")
pd_config.build()
print(pd_config.do_train)
print(pd_config.hidden_size)
pd_config = PDConfig(yaml_file = "./test/bert_config.yaml")
pd_config.build()
print(pd_config.do_train)
print(pd_config.hidden_size)
"""
pd_config
=
PDConfig
(
yaml_file
=
"./test/bert_config.yaml"
)
pd_config
+=
(
"my_age"
,
int
,
18
,
"I am forever 18."
)
pd_config
.
build
()
print
(
pd_config
.
do_train
)
print
(
pd_config
.
hidden_size
)
print
(
pd_config
.
my_age
)
PaddleNLP/neural_machine_translation/transformer/palm/toolkit/input_field.py
0 → 100644
浏览文件 @
80628bc6
#encoding=utf8
from
__future__
import
print_function
from
__future__
import
division
from
__future__
import
print_function
import
os
import
six
import
ast
import
copy
import
numpy
as
np
import
paddle.fluid
as
fluid
class
Placeholder
(
object
):
def
__init__
(
self
):
self
.
shapes
=
[]
self
.
dtypes
=
[]
self
.
lod_levels
=
[]
self
.
names
=
[]
def
__init__
(
self
,
input_shapes
):
self
.
shapes
=
[]
self
.
dtypes
=
[]
self
.
lod_levels
=
[]
self
.
names
=
[]
for
new_holder
in
input_shapes
:
shape
=
new_holder
[
0
]
dtype
=
new_holder
[
1
]
lod_level
=
new_holder
[
2
]
if
len
(
new_holder
)
>=
3
else
0
name
=
new_holder
[
3
]
if
len
(
new_holder
)
>=
4
else
""
self
.
append_placeholder
(
shape
,
dtype
,
lod_level
=
lod_level
,
name
=
name
)
def
append_placeholder
(
self
,
shape
,
dtype
,
lod_level
=
0
,
name
=
""
):
self
.
shapes
.
append
(
shape
)
self
.
dtypes
.
append
(
dtype
)
self
.
lod_levels
.
append
(
lod_level
)
self
.
names
.
append
(
name
)
def
build
(
self
,
capacity
,
reader_name
,
use_double_buffer
=
False
):
pyreader
=
fluid
.
layers
.
py_reader
(
capacity
=
capacity
,
shapes
=
self
.
shapes
,
dtypes
=
self
.
dtypes
,
lod_levels
=
self
.
lod_levels
,
name
=
reader_name
,
use_double_buffer
=
use_double_buffer
)
return
[
pyreader
,
fluid
.
layers
.
read_file
(
pyreader
)]
def
__add__
(
self
,
new_holder
):
assert
isinstance
(
new_holder
,
tuple
)
or
isinstance
(
new_holder
,
list
)
assert
len
(
new_holder
)
>=
2
shape
=
new_holder
[
0
]
dtype
=
new_holder
[
1
]
lod_level
=
new_holder
[
2
]
if
len
(
new_holder
)
>=
3
else
0
name
=
new_holder
[
3
]
if
len
(
new_holder
)
>=
4
else
""
self
.
append_placeholder
(
shape
,
dtype
,
lod_level
=
lod_level
,
name
=
name
)
class
InputField
(
object
):
"""
A high-level API for handling inputs in PaddlePaddle.
"""
def
__init__
(
self
,
input_slots
=
[]):
self
.
shapes
=
[]
self
.
dtypes
=
[]
self
.
names
=
[]
self
.
lod_levels
=
[]
self
.
input_slots
=
{}
self
.
feed_list_str
=
[]
self
.
feed_list
=
[]
self
.
reader
=
None
if
input_slots
:
for
input_slot
in
input_slots
:
self
+=
input_slot
def
__add__
(
self
,
input_slot
):
if
isinstance
(
input_slot
,
list
)
or
isinstance
(
input_slot
,
tuple
):
name
=
input_slot
[
0
]
shape
=
input_slot
[
1
]
dtype
=
input_slot
[
2
]
lod_level
=
input_slot
[
3
]
if
len
(
input_slot
)
==
4
else
0
if
isinstance
(
input_slot
,
dict
):
name
=
input_slot
[
"name"
]
shape
=
input_slot
[
"shape"
]
dtype
=
input_slot
[
"dtype"
]
lod_level
=
input_slot
[
"lod_level"
]
if
"lod_level"
in
input_slot
else
0
self
.
shapes
.
append
(
shape
)
self
.
dtypes
.
append
(
dtype
)
self
.
names
.
append
(
name
)
self
.
lod_levels
.
append
(
lod_level
)
self
.
feed_list_str
.
append
(
name
)
return
self
def
__getattr__
(
self
,
name
):
if
name
not
in
self
.
input_slots
:
raise
Warning
(
"the attr %s has not been defined yet."
%
name
)
return
None
return
self
.
input_slots
[
name
]
def
build
(
self
,
build_pyreader
=
False
,
capacity
=
100
,
iterable
=
False
):
for
_name
,
_shape
,
_dtype
,
_lod_level
in
zip
(
self
.
names
,
self
.
shapes
,
self
.
dtypes
,
self
.
lod_levels
):
self
.
input_slots
[
_name
]
=
fluid
.
layers
.
data
(
name
=
_name
,
shape
=
_shape
,
dtype
=
_dtype
,
lod_level
=
_lod_level
)
for
name
in
self
.
feed_list_str
:
self
.
feed_list
.
append
(
self
.
input_slots
[
name
])
if
build_pyreader
:
self
.
reader
=
fluid
.
io
.
PyReader
(
feed_list
=
self
.
feed_list
,
capacity
=
capacity
,
iterable
=
iterable
)
def
start
(
self
,
generator
=
None
):
if
generator
is
not
None
:
self
.
reader
.
decorate_batch_generator
(
generator
)
self
.
reader
.
start
()
if
__name__
==
"__main__"
:
mnist_input_slots
=
[{
"name"
:
"image"
,
"shape"
:
(
-
1
,
32
,
32
,
1
),
"dtype"
:
"int32"
},
{
"name"
:
"label"
,
"shape"
:
[
-
1
,
1
],
"dtype"
:
"int64"
}]
input_field
=
InputField
(
mnist_input_slots
)
input_field
+=
{
"name"
:
"large_image"
,
"shape"
:
(
-
1
,
64
,
64
,
1
),
"dtype"
:
"int32"
}
input_field
+=
{
"name"
:
"large_color_image"
,
"shape"
:
(
-
1
,
64
,
64
,
3
),
"dtype"
:
"int32"
}
input_field
.
build
()
print
(
input_field
.
feed_list
)
print
(
input_field
.
image
)
print
(
input_field
.
large_color_image
)
PaddleNLP/neural_machine_translation/transformer/predict.py
0 → 100644
浏览文件 @
80628bc6
#encoding=utf8
import
logging
import
os
import
six
import
sys
import
time
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
#include palm for easier nlp coding
from
palm.toolkit.input_field
import
InputField
from
palm.toolkit.configure
import
PDConfig
# include task-specific libs
import
desc
import
reader
from
transformer
import
create_net
,
position_encoding_init
def
init_from_pretrain_model
(
args
,
exe
,
program
):
assert
isinstance
(
args
.
init_from_pretrain_model
,
str
)
if
not
os
.
path
.
exists
(
args
.
init_from_pretrain_model
):
raise
Warning
(
"The pretrained params do not exist."
)
return
False
def
existed_params
(
var
):
if
not
isinstance
(
var
,
fluid
.
framework
.
Parameter
):
return
False
return
os
.
path
.
exists
(
os
.
path
.
join
(
args
.
init_from_pretrain_model
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
args
.
init_from_pretrain_model
,
main_program
=
program
,
predicate
=
existed_params
)
print
(
"finish initing model from pretrained params from %s"
%
(
args
.
init_from_pretrain_model
))
return
True
def
init_from_params
(
args
,
exe
,
program
):
assert
isinstance
(
args
.
init_from_params
,
str
)
if
not
os
.
path
.
exists
(
args
.
init_from_params
):
raise
Warning
(
"the params path does not exist."
)
return
False
fluid
.
io
.
load_params
(
executor
=
exe
,
dirname
=
args
.
init_from_params
,
main_program
=
program
,
filename
=
"params.pdparams"
)
print
(
"finish init model from params from %s"
%
(
args
.
init_from_params
))
return
True
def
post_process_seq
(
seq
,
bos_idx
,
eos_idx
,
output_bos
=
False
,
output_eos
=
False
):
"""
Post-process the beam-search decoded sequence. Truncate from the first
<eos> and remove the <bos> and <eos> tokens currently.
"""
eos_pos
=
len
(
seq
)
-
1
for
i
,
idx
in
enumerate
(
seq
):
if
idx
==
eos_idx
:
eos_pos
=
i
break
seq
=
[
idx
for
idx
in
seq
[:
eos_pos
+
1
]
if
(
output_bos
or
idx
!=
bos_idx
)
and
(
output_eos
or
idx
!=
eos_idx
)
]
return
seq
def
do_predict
(
args
):
if
args
.
use_cuda
:
dev_count
=
fluid
.
core
.
get_cuda_device_count
()
place
=
fluid
.
CUDAPlace
(
0
)
else
:
dev_count
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
1
))
place
=
fluid
.
CPUPlace
()
# define the data generator
processor
=
reader
.
DataProcessor
(
fpattern
=
args
.
predict_file
,
src_vocab_fpath
=
args
.
src_vocab_fpath
,
trg_vocab_fpath
=
args
.
trg_vocab_fpath
,
token_delimiter
=
args
.
token_delimiter
,
use_token_batch
=
False
,
batch_size
=
args
.
batch_size
,
device_count
=
dev_count
,
pool_size
=
args
.
pool_size
,
sort_type
=
reader
.
SortType
.
NONE
,
shuffle
=
False
,
shuffle_batch
=
False
,
start_mark
=
args
.
special_token
[
0
],
end_mark
=
args
.
special_token
[
1
],
unk_mark
=
args
.
special_token
[
2
],
max_length
=
args
.
max_length
,
n_head
=
args
.
n_head
)
batch_generator
=
processor
.
data_generator
(
phase
=
"predict"
,
place
=
place
)
args
.
src_vocab_size
,
args
.
trg_vocab_size
,
args
.
bos_idx
,
args
.
eos_idx
,
\
args
.
unk_idx
=
processor
.
get_vocab_summary
()
trg_idx2word
=
reader
.
DataProcessor
.
load_dict
(
dict_path
=
args
.
trg_vocab_fpath
,
reverse
=
True
)
test_prog
=
fluid
.
default_main_program
()
startup_prog
=
fluid
.
default_startup_program
()
with
fluid
.
program_guard
(
test_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
# define input and reader
input_field_names
=
desc
.
encoder_data_input_fields
+
desc
.
fast_decoder_data_input_fields
input_slots
=
[{
"name"
:
name
,
"shape"
:
desc
.
input_descs
[
name
][
0
],
"dtype"
:
desc
.
input_descs
[
name
][
1
]
}
for
name
in
input_field_names
]
input_field
=
InputField
(
input_slots
)
input_field
.
build
(
build_pyreader
=
True
)
# define the network
predictions
=
create_net
(
is_training
=
False
,
model_input
=
input_field
,
args
=
args
)
out_ids
,
out_scores
=
predictions
out_ids
.
persistable
=
out_scores
.
persistable
=
True
# This is used here to set dropout to the test mode.
test_prog
=
test_prog
.
clone
(
for_test
=
True
)
# prepare predicting
## define the executor and program for training
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_prog
)
assert
(
args
.
init_from_params
)
or
(
args
.
init_from_pretrain_model
)
if
args
.
init_from_params
:
init_from_params
(
args
,
exe
,
test_prog
)
elif
args
.
init_from_pretrain_model
:
init_from_pretrain_model
(
args
,
exe
,
test_prog
)
# to avoid a longer length than training, reset the size of position encoding to max_length
for
pos_enc_param_name
in
desc
.
pos_enc_param_names
:
pos_enc_param
=
fluid
.
global_scope
().
find_var
(
pos_enc_param_name
).
get_tensor
()
pos_enc_param
.
set
(
position_encoding_init
(
args
.
max_length
+
1
,
args
.
d_model
),
place
)
compiled_test_prog
=
fluid
.
CompiledProgram
(
test_prog
)
f
=
open
(
args
.
output_file
,
"wb"
)
# start predicting
## decorate the pyreader with batch_generator
input_field
.
reader
.
decorate_batch_generator
(
batch_generator
)
input_field
.
reader
.
start
()
while
True
:
try
:
seq_ids
,
seq_scores
=
exe
.
run
(
compiled_test_prog
,
fetch_list
=
[
out_ids
.
name
,
out_scores
.
name
],
return_numpy
=
False
)
# How to parse the results:
# Suppose the lod of seq_ids is:
# [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
# then from lod[0]:
# there are 2 source sentences, beam width is 3.
# from lod[1]:
# the first source sentence has 3 hyps; the lengths are 12, 12, 16
# the second source sentence has 3 hyps; the lengths are 14, 13, 15
hyps
=
[[]
for
i
in
range
(
len
(
seq_ids
.
lod
()[
0
])
-
1
)]
scores
=
[[]
for
i
in
range
(
len
(
seq_scores
.
lod
()[
0
])
-
1
)]
for
i
in
range
(
len
(
seq_ids
.
lod
()[
0
])
-
1
):
# for each source sentence
start
=
seq_ids
.
lod
()[
0
][
i
]
end
=
seq_ids
.
lod
()[
0
][
i
+
1
]
for
j
in
range
(
end
-
start
):
# for each candidate
sub_start
=
seq_ids
.
lod
()[
1
][
start
+
j
]
sub_end
=
seq_ids
.
lod
()[
1
][
start
+
j
+
1
]
hyps
[
i
].
append
(
b
" "
.
join
([
trg_idx2word
[
idx
]
for
idx
in
post_process_seq
(
np
.
array
(
seq_ids
)[
sub_start
:
sub_end
],
args
.
bos_idx
,
args
.
eos_idx
)
]))
scores
[
i
].
append
(
np
.
array
(
seq_scores
)[
sub_end
-
1
])
f
.
write
(
hyps
[
i
][
-
1
]
+
b
"
\n
"
)
if
len
(
hyps
[
i
])
>=
args
.
n_best
:
break
except
fluid
.
core
.
EOFException
:
break
f
.
close
()
if
__name__
==
"__main__"
:
args
=
PDConfig
(
yaml_file
=
"./transformer.yaml"
)
args
.
build
()
args
.
Print
()
do_predict
(
args
)
PaddleNLP/neural_machine_translation/transformer/reader.py
浏览文件 @
80628bc6
...
@@ -4,6 +4,134 @@ import os
...
@@ -4,6 +4,134 @@ import os
import
tarfile
import
tarfile
import
numpy
as
np
import
numpy
as
np
import
paddle.fluid
as
fluid
def
pad_batch_data
(
insts
,
pad_idx
,
n_head
,
is_target
=
False
,
is_label
=
False
,
return_attn_bias
=
True
,
return_max_len
=
True
,
return_num_token
=
False
):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
return_list
=
[]
max_len
=
max
(
len
(
inst
)
for
inst
in
insts
)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data
=
np
.
array
(
[
inst
+
[
pad_idx
]
*
(
max_len
-
len
(
inst
))
for
inst
in
insts
])
return_list
+=
[
inst_data
.
astype
(
"int64"
).
reshape
([
-
1
,
1
])]
if
is_label
:
# label weight
inst_weight
=
np
.
array
([[
1.
]
*
len
(
inst
)
+
[
0.
]
*
(
max_len
-
len
(
inst
))
for
inst
in
insts
])
return_list
+=
[
inst_weight
.
astype
(
"float32"
).
reshape
([
-
1
,
1
])]
else
:
# position data
inst_pos
=
np
.
array
([
list
(
range
(
0
,
len
(
inst
)))
+
[
0
]
*
(
max_len
-
len
(
inst
))
for
inst
in
insts
])
return_list
+=
[
inst_pos
.
astype
(
"int64"
).
reshape
([
-
1
,
1
])]
if
return_attn_bias
:
if
is_target
:
# This is used to avoid attention on paddings and subsequent
# words.
slf_attn_bias_data
=
np
.
ones
((
inst_data
.
shape
[
0
],
max_len
,
max_len
))
slf_attn_bias_data
=
np
.
triu
(
slf_attn_bias_data
,
1
).
reshape
([
-
1
,
1
,
max_len
,
max_len
])
slf_attn_bias_data
=
np
.
tile
(
slf_attn_bias_data
,
[
1
,
n_head
,
1
,
1
])
*
[
-
1e9
]
else
:
# This is used to avoid attention on paddings.
slf_attn_bias_data
=
np
.
array
([[
0
]
*
len
(
inst
)
+
[
-
1e9
]
*
(
max_len
-
len
(
inst
))
for
inst
in
insts
])
slf_attn_bias_data
=
np
.
tile
(
slf_attn_bias_data
.
reshape
([
-
1
,
1
,
1
,
max_len
]),
[
1
,
n_head
,
max_len
,
1
])
return_list
+=
[
slf_attn_bias_data
.
astype
(
"float32"
)]
if
return_max_len
:
return_list
+=
[
max_len
]
if
return_num_token
:
num_token
=
0
for
inst
in
insts
:
num_token
+=
len
(
inst
)
return_list
+=
[
num_token
]
return
return_list
if
len
(
return_list
)
>
1
else
return_list
[
0
]
def
prepare_train_input
(
insts
,
src_pad_idx
,
trg_pad_idx
,
n_head
):
"""
Put all padded data needed by training into a list.
"""
src_word
,
src_pos
,
src_slf_attn_bias
,
src_max_len
=
pad_batch_data
(
[
inst
[
0
]
for
inst
in
insts
],
src_pad_idx
,
n_head
,
is_target
=
False
)
src_word
=
src_word
.
reshape
(
-
1
,
src_max_len
,
1
)
src_pos
=
src_pos
.
reshape
(
-
1
,
src_max_len
,
1
)
trg_word
,
trg_pos
,
trg_slf_attn_bias
,
trg_max_len
=
pad_batch_data
(
[
inst
[
1
]
for
inst
in
insts
],
trg_pad_idx
,
n_head
,
is_target
=
True
)
trg_word
=
trg_word
.
reshape
(
-
1
,
trg_max_len
,
1
)
trg_pos
=
trg_pos
.
reshape
(
-
1
,
trg_max_len
,
1
)
trg_src_attn_bias
=
np
.
tile
(
src_slf_attn_bias
[:,
:,
::
src_max_len
,
:],
[
1
,
1
,
trg_max_len
,
1
]).
astype
(
"float32"
)
lbl_word
,
lbl_weight
,
num_token
=
pad_batch_data
(
[
inst
[
2
]
for
inst
in
insts
],
trg_pad_idx
,
n_head
,
is_target
=
False
,
is_label
=
True
,
return_attn_bias
=
False
,
return_max_len
=
False
,
return_num_token
=
True
)
data_inputs
=
[
src_word
,
src_pos
,
src_slf_attn_bias
,
trg_word
,
trg_pos
,
trg_slf_attn_bias
,
trg_src_attn_bias
,
lbl_word
,
lbl_weight
]
return
data_inputs
def
prepare_infer_input
(
insts
,
src_pad_idx
,
bos_idx
,
n_head
,
place
):
"""
Put all padded data needed by beam search decoder into a list.
"""
src_word
,
src_pos
,
src_slf_attn_bias
,
src_max_len
=
pad_batch_data
(
[
inst
[
0
]
for
inst
in
insts
],
src_pad_idx
,
n_head
,
is_target
=
False
)
# start tokens
trg_word
=
np
.
asarray
([[
bos_idx
]]
*
len
(
insts
),
dtype
=
"int64"
)
trg_src_attn_bias
=
np
.
tile
(
src_slf_attn_bias
[:,
:,
::
src_max_len
,
:],
[
1
,
1
,
1
,
1
]).
astype
(
"float32"
)
trg_word
=
trg_word
.
reshape
(
-
1
,
1
,
1
)
src_word
=
src_word
.
reshape
(
-
1
,
src_max_len
,
1
)
src_pos
=
src_pos
.
reshape
(
-
1
,
src_max_len
,
1
)
def
to_lodtensor
(
data
,
place
,
lod
=
None
):
data_tensor
=
fluid
.
LoDTensor
()
data_tensor
.
set
(
data
,
place
)
if
lod
is
not
None
:
data_tensor
.
set_lod
(
lod
)
return
data_tensor
# beamsearch_op must use tensors with lod
init_score
=
to_lodtensor
(
np
.
zeros_like
(
trg_word
,
dtype
=
"float32"
).
reshape
(
-
1
,
1
),
place
,
[
range
(
trg_word
.
shape
[
0
]
+
1
)]
*
2
)
trg_word
=
to_lodtensor
(
trg_word
,
place
,
[
range
(
trg_word
.
shape
[
0
]
+
1
)]
*
2
)
init_idx
=
np
.
asarray
(
range
(
len
(
insts
)),
dtype
=
"int32"
)
data_inputs
=
[
src_word
,
src_pos
,
src_slf_attn_bias
,
trg_word
,
init_score
,
init_idx
,
trg_src_attn_bias
]
return
data_inputs
class
SortType
(
object
):
class
SortType
(
object
):
...
@@ -95,7 +223,7 @@ class MinMaxFilter(object):
...
@@ -95,7 +223,7 @@ class MinMaxFilter(object):
return
self
.
_creator
.
batch
return
self
.
_creator
.
batch
class
Data
Reade
r
(
object
):
class
Data
Processo
r
(
object
):
"""
"""
The data reader loads all data from files and produces batches of data
The data reader loads all data from files and produces batches of data
in the way corresponding to settings.
in the way corresponding to settings.
...
@@ -104,12 +232,14 @@ class DataReader(object):
...
@@ -104,12 +232,14 @@ class DataReader(object):
is shuffled in each pass and sorted in each pool:
is shuffled in each pass and sorted in each pool:
```
```
train_data = Data
Reade
r(
train_data = Data
Processo
r(
src_vocab_fpath='data/src_vocab_file',
src_vocab_fpath='data/src_vocab_file',
trg_vocab_fpath='data/trg_vocab_file',
trg_vocab_fpath='data/trg_vocab_file',
fpattern='data/part-*',
fpattern='data/part-*',
use_token_batch=True,
use_token_batch=True,
batch_size=2000,
batch_size=2000,
device_count=8,
n_head=8,
pool_size=10000,
pool_size=10000,
sort_type=SortType.POOL,
sort_type=SortType.POOL,
shuffle=True,
shuffle=True,
...
@@ -117,7 +247,7 @@ class DataReader(object):
...
@@ -117,7 +247,7 @@ class DataReader(object):
start_mark='<s>',
start_mark='<s>',
end_mark='<e>',
end_mark='<e>',
unk_mark='<unk>',
unk_mark='<unk>',
clip_last_batch=False).
batch_generator
clip_last_batch=False).
data_generator(phase='train')
```
```
:param src_vocab_fpath: The path of vocabulary file of source language.
:param src_vocab_fpath: The path of vocabulary file of source language.
...
@@ -131,6 +261,12 @@ class DataReader(object):
...
@@ -131,6 +261,12 @@ class DataReader(object):
mini-batch.
mini-batch.
:type batch_size: int
:type batch_size: int
:param pool_size: The size of pool buffer.
:param pool_size: The size of pool buffer.
:type device_count: int
:param device_count: The number of devices. The actual batch size is
determined by both batch_size and device_count.
:type n_head: int
:param n_head: The number of head used in multi-head attention. Actually,
this is not a reader related argument, but is used for input data.
:type pool_size: int
:type pool_size: int
:param sort_type: The grain to sort by length: 'global' for all
:param sort_type: The grain to sort by length: 'global' for all
instances; 'pool' for instances in pool; 'none' for no sort.
instances; 'pool' for instances in pool; 'none' for no sort.
...
@@ -164,6 +300,9 @@ class DataReader(object):
...
@@ -164,6 +300,9 @@ class DataReader(object):
:type end_mark: basestring
:type end_mark: basestring
:param unk_mark: The token representing for unknown word in dictionary.
:param unk_mark: The token representing for unknown word in dictionary.
:type unk_mark: basestring
:type unk_mark: basestring
:param only_src: Whether each line is a source and target sentence
pair or only has the source sentence.
:type only_src: bool
:param seed: The seed for random.
:param seed: The seed for random.
:type seed: int
:type seed: int
"""
"""
...
@@ -173,14 +312,15 @@ class DataReader(object):
...
@@ -173,14 +312,15 @@ class DataReader(object):
trg_vocab_fpath
,
trg_vocab_fpath
,
fpattern
,
fpattern
,
batch_size
,
batch_size
,
device_count
,
n_head
,
pool_size
,
pool_size
,
sort_type
=
SortType
.
GLOBAL
,
sort_type
=
SortType
.
GLOBAL
,
clip_last_batch
=
Tru
e
,
clip_last_batch
=
Fals
e
,
tar_fname
=
None
,
tar_fname
=
None
,
min_length
=
0
,
min_length
=
0
,
max_length
=
100
,
max_length
=
100
,
shuffle
=
True
,
shuffle
=
True
,
shuffle_seed
=
None
,
shuffle_batch
=
False
,
shuffle_batch
=
False
,
use_token_batch
=
False
,
use_token_batch
=
False
,
field_delimiter
=
"
\t
"
,
field_delimiter
=
"
\t
"
,
...
@@ -188,37 +328,44 @@ class DataReader(object):
...
@@ -188,37 +328,44 @@ class DataReader(object):
start_mark
=
"<s>"
,
start_mark
=
"<s>"
,
end_mark
=
"<e>"
,
end_mark
=
"<e>"
,
unk_mark
=
"<unk>"
,
unk_mark
=
"<unk>"
,
only_src
=
False
,
seed
=
0
):
seed
=
0
):
# convert str to bytes, and use byte data
field_delimiter
=
field_delimiter
.
encode
(
"utf8"
)
token_delimiter
=
token_delimiter
.
encode
(
"utf8"
)
start_mark
=
start_mark
.
encode
(
"utf8"
)
end_mark
=
end_mark
.
encode
(
"utf8"
)
unk_mark
=
unk_mark
.
encode
(
"utf8"
)
self
.
_src_vocab
=
self
.
load_dict
(
src_vocab_fpath
)
self
.
_src_vocab
=
self
.
load_dict
(
src_vocab_fpath
)
self
.
_only_src
=
True
if
trg_vocab_fpath
is
not
None
:
self
.
_trg_vocab
=
self
.
load_dict
(
trg_vocab_fpath
)
self
.
_trg_vocab
=
self
.
load_dict
(
trg_vocab_fpath
)
self
.
_only_src
=
False
self
.
_bos_idx
=
self
.
_src_vocab
[
start_mark
]
self
.
_eos_idx
=
self
.
_src_vocab
[
end_mark
]
self
.
_unk_idx
=
self
.
_src_vocab
[
unk_mark
]
self
.
_only_src
=
only_src
self
.
_pool_size
=
pool_size
self
.
_pool_size
=
pool_size
self
.
_batch_size
=
batch_size
self
.
_batch_size
=
batch_size
self
.
_device_count
=
device_count
self
.
_n_head
=
n_head
self
.
_use_token_batch
=
use_token_batch
self
.
_use_token_batch
=
use_token_batch
self
.
_sort_type
=
sort_type
self
.
_sort_type
=
sort_type
self
.
_clip_last_batch
=
clip_last_batch
self
.
_clip_last_batch
=
clip_last_batch
self
.
_shuffle
=
shuffle
self
.
_shuffle
=
shuffle
self
.
_shuffle_seed
=
shuffle_seed
self
.
_shuffle_batch
=
shuffle_batch
self
.
_shuffle_batch
=
shuffle_batch
self
.
_min_length
=
min_length
self
.
_min_length
=
min_length
self
.
_max_length
=
max_length
self
.
_max_length
=
max_length
self
.
_field_delimiter
=
field_delimiter
self
.
_field_delimiter
=
field_delimiter
self
.
_token_delimiter
=
token_delimiter
self
.
_token_delimiter
=
token_delimiter
self
.
load_src_trg_ids
(
end_mark
,
fpattern
,
start_mark
,
tar_fname
,
self
.
load_src_trg_ids
(
fpattern
,
tar_fname
)
unk_mark
)
self
.
_random
=
np
.
random
self
.
_random
=
np
.
random
self
.
_random
.
seed
(
seed
)
self
.
_random
.
seed
(
seed
)
def
load_src_trg_ids
(
self
,
end_mark
,
fpattern
,
start_mark
,
tar_fname
,
def
load_src_trg_ids
(
self
,
fpattern
,
tar_fname
):
unk_mark
):
converters
=
[
converters
=
[
Converter
(
Converter
(
vocab
=
self
.
_src_vocab
,
vocab
=
self
.
_src_vocab
,
beg
=
self
.
_
src_vocab
[
start_mark
]
,
beg
=
self
.
_
bos_idx
,
end
=
self
.
_
src_vocab
[
end_mark
]
,
end
=
self
.
_
eos_idx
,
unk
=
self
.
_
src_vocab
[
unk_mark
]
,
unk
=
self
.
_
unk_idx
,
delimiter
=
self
.
_token_delimiter
,
delimiter
=
self
.
_token_delimiter
,
add_beg
=
False
)
add_beg
=
False
)
]
]
...
@@ -226,9 +373,9 @@ class DataReader(object):
...
@@ -226,9 +373,9 @@ class DataReader(object):
converters
.
append
(
converters
.
append
(
Converter
(
Converter
(
vocab
=
self
.
_trg_vocab
,
vocab
=
self
.
_trg_vocab
,
beg
=
self
.
_
trg_vocab
[
start_mark
]
,
beg
=
self
.
_
bos_idx
,
end
=
self
.
_
trg_vocab
[
end_mark
]
,
end
=
self
.
_
eos_idx
,
unk
=
self
.
_
trg_vocab
[
unk_mark
]
,
unk
=
self
.
_
unk_idx
,
delimiter
=
self
.
_token_delimiter
,
delimiter
=
self
.
_token_delimiter
,
add_beg
=
True
))
add_beg
=
True
))
...
@@ -254,9 +401,9 @@ class DataReader(object):
...
@@ -254,9 +401,9 @@ class DataReader(object):
if
tar_fname
is
None
:
if
tar_fname
is
None
:
raise
Exception
(
"If tar file provided, please set tar_fname."
)
raise
Exception
(
"If tar file provided, please set tar_fname."
)
f
=
tarfile
.
open
(
fpaths
[
0
],
"r"
)
f
=
tarfile
.
open
(
fpaths
[
0
],
"r
b
"
)
for
line
in
f
.
extractfile
(
tar_fname
):
for
line
in
f
.
extractfile
(
tar_fname
):
fields
=
line
.
strip
(
"
\n
"
).
split
(
self
.
_field_delimiter
)
fields
=
line
.
strip
(
b
"
\n
"
).
split
(
self
.
_field_delimiter
)
if
(
not
self
.
_only_src
and
len
(
fields
)
==
2
)
or
(
if
(
not
self
.
_only_src
and
len
(
fields
)
==
2
)
or
(
self
.
_only_src
and
len
(
fields
)
==
1
):
self
.
_only_src
and
len
(
fields
)
==
1
):
yield
fields
yield
fields
...
@@ -267,9 +414,7 @@ class DataReader(object):
...
@@ -267,9 +414,7 @@ class DataReader(object):
with
open
(
fpath
,
"rb"
)
as
f
:
with
open
(
fpath
,
"rb"
)
as
f
:
for
line
in
f
:
for
line
in
f
:
if
six
.
PY3
:
fields
=
line
.
strip
(
b
"
\n
"
).
split
(
self
.
_field_delimiter
)
line
=
line
.
decode
(
"utf8"
,
errors
=
"ignore"
)
fields
=
line
.
strip
(
"
\n
"
).
split
(
self
.
_field_delimiter
)
if
(
not
self
.
_only_src
and
len
(
fields
)
==
2
)
or
(
if
(
not
self
.
_only_src
and
len
(
fields
)
==
2
)
or
(
self
.
_only_src
and
len
(
fields
)
==
1
):
self
.
_only_src
and
len
(
fields
)
==
1
):
yield
fields
yield
fields
...
@@ -279,23 +424,20 @@ class DataReader(object):
...
@@ -279,23 +424,20 @@ class DataReader(object):
word_dict
=
{}
word_dict
=
{}
with
open
(
dict_path
,
"rb"
)
as
fdict
:
with
open
(
dict_path
,
"rb"
)
as
fdict
:
for
idx
,
line
in
enumerate
(
fdict
):
for
idx
,
line
in
enumerate
(
fdict
):
if
six
.
PY3
:
line
=
line
.
decode
(
"utf8"
,
errors
=
"ignore"
)
if
reverse
:
if
reverse
:
word_dict
[
idx
]
=
line
.
strip
(
"
\n
"
)
word_dict
[
idx
]
=
line
.
strip
(
b
"
\n
"
)
else
:
else
:
word_dict
[
line
.
strip
(
"
\n
"
)]
=
idx
word_dict
[
line
.
strip
(
b
"
\n
"
)]
=
idx
return
word_dict
return
word_dict
def
batch_generator
(
self
):
def
batch_generator
(
self
,
batch_size
,
use_token_batch
):
def
__impl__
():
# global sort or global shuffle
# global sort or global shuffle
if
self
.
_sort_type
==
SortType
.
GLOBAL
:
if
self
.
_sort_type
==
SortType
.
GLOBAL
:
infos
=
sorted
(
self
.
_sample_infos
,
key
=
lambda
x
:
x
.
max_len
)
infos
=
sorted
(
self
.
_sample_infos
,
key
=
lambda
x
:
x
.
max_len
)
else
:
else
:
if
self
.
_shuffle
:
if
self
.
_shuffle
:
infos
=
self
.
_sample_infos
infos
=
self
.
_sample_infos
if
self
.
_shuffle_seed
is
not
None
:
self
.
_random
.
seed
(
self
.
_shuffle_seed
)
self
.
_random
.
shuffle
(
infos
)
self
.
_random
.
shuffle
(
infos
)
else
:
else
:
infos
=
self
.
_sample_infos
infos
=
self
.
_sample_infos
...
@@ -313,8 +455,8 @@ class DataReader(object):
...
@@ -313,8 +455,8 @@ class DataReader(object):
# concat batch
# concat batch
batches
=
[]
batches
=
[]
batch_creator
=
TokenBatchCreator
(
batch_creator
=
TokenBatchCreator
(
self
.
_batch_size
batch_size
)
if
use_token_batch
else
SentenceBatchCreator
(
)
if
self
.
_use_token_batch
else
SentenceBatchCreator
(
self
.
_
batch_size
)
batch_size
)
batch_creator
=
MinMaxFilter
(
self
.
_max_length
,
self
.
_min_length
,
batch_creator
=
MinMaxFilter
(
self
.
_max_length
,
self
.
_min_length
,
batch_creator
)
batch_creator
)
...
@@ -337,3 +479,72 @@ class DataReader(object):
...
@@ -337,3 +479,72 @@ class DataReader(object):
else
:
else
:
yield
[(
self
.
_src_seq_ids
[
idx
],
self
.
_trg_seq_ids
[
idx
][:
-
1
],
yield
[(
self
.
_src_seq_ids
[
idx
],
self
.
_trg_seq_ids
[
idx
][:
-
1
],
self
.
_trg_seq_ids
[
idx
][
1
:])
for
idx
in
batch_ids
]
self
.
_trg_seq_ids
[
idx
][
1
:])
for
idx
in
batch_ids
]
return
__impl__
@
staticmethod
def
stack
(
data_reader
,
count
,
clip_last
=
True
):
def
__impl__
():
res
=
[]
for
item
in
data_reader
():
res
.
append
(
item
)
if
len
(
res
)
==
count
:
yield
res
res
=
[]
if
len
(
res
)
==
count
:
yield
res
elif
not
clip_last
:
data
=
[]
for
item
in
res
:
data
+=
item
if
len
(
data
)
>
count
:
inst_num_per_part
=
len
(
data
)
//
count
yield
[
data
[
inst_num_per_part
*
i
:
inst_num_per_part
*
(
i
+
1
)]
for
i
in
range
(
count
)
]
return
__impl__
@
staticmethod
def
split
(
data_reader
,
count
):
def
__impl__
():
for
item
in
data_reader
():
inst_num_per_part
=
len
(
item
)
//
count
for
i
in
range
(
count
):
yield
item
[
inst_num_per_part
*
i
:
inst_num_per_part
*
(
i
+
1
)]
return
__impl__
def
data_generator
(
self
,
phase
,
place
=
None
):
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
src_pad_idx
=
trg_pad_idx
=
self
.
_eos_idx
bos_idx
=
self
.
_bos_idx
n_head
=
self
.
_n_head
data_reader
=
self
.
batch_generator
(
self
.
_batch_size
*
(
1
if
self
.
_use_token_batch
else
self
.
_device_count
),
self
.
_use_token_batch
)
if
not
self
.
_use_token_batch
:
# to make data on each device have similar token number
data_reader
=
self
.
split
(
data_reader
,
self
.
_device_count
)
def
__for_train__
():
for
data
in
data_reader
():
data_inputs
=
prepare_train_input
(
data
,
src_pad_idx
,
trg_pad_idx
,
n_head
)
yield
data_inputs
def
__for_predict__
():
for
data
in
data_reader
():
data_inputs
=
prepare_infer_input
(
data
,
src_pad_idx
,
bos_idx
,
n_head
,
place
)
yield
data_inputs
return
__for_train__
if
phase
==
"train"
else
__for_predict__
def
get_vocab_summary
(
self
):
return
len
(
self
.
_src_vocab
),
len
(
self
.
_trg_vocab
),
self
.
_bos_idx
,
self
.
_eos_idx
,
self
.
_unk_idx
PaddleNLP/neural_machine_translation/transformer/train.py
浏览文件 @
80628bc6
此差异已折叠。
点击以展开。
PaddleNLP/neural_machine_translation/transformer/transformer.py
0 → 100644
浏览文件 @
80628bc6
此差异已折叠。
点击以展开。
PaddleNLP/neural_machine_translation/transformer/transformer.yaml
0 → 100644
浏览文件 @
80628bc6
# used for continuous evaluation
enable_ce
:
False
# The frequency to save trained models when training.
save_step
:
10000
# The frequency to fetch and print output when training.
print_step
:
100
# path of the checkpoint, to resume the previous training
init_from_checkpoint
:
"
"
# path of the pretrain model, to better solve the current task
init_from_pretrain_model
:
"
"
# path of trained parameter, to make prediction
init_from_params
:
"
trained_params/step_100000"
save_model_path
:
"
"
# the directory for saving checkpoints.
save_checkpoint
:
"
trained_ckpts"
# the directory for saving trained parameters.
save_param
:
"
trained_params"
# the directory for saving inference model.
inference_model_dir
:
"
infer_model"
# Set seed for CE or debug
random_seed
:
None
# The pattern to match training data files.
training_file
:
"
wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de"
# The pattern to match test data files.
predict_file
:
"
wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de"
# The file to output the translation results of predict_file to.
output_file
:
"
predict.txt"
# The path of vocabulary file of source language.
src_vocab_fpath
:
"
wmt16_ende_data_bpe/vocab_all.bpe.32000"
# The path of vocabulary file of target language.
trg_vocab_fpath
:
"
wmt16_ende_data_bpe/vocab_all.bpe.32000"
# The <bos>, <eos> and <unk> tokens in the dictionary.
special_token
:
[
"
<s>"
,
"
<e>"
,
"
<unk>"
]
# whether to use cuda
use_cuda
:
True
# args for reader, see reader.py for details
token_delimiter
:
"
"
use_token_batch
:
True
pool_size
:
200000
sort_type
:
"
pool"
shuffle
:
True
shuffle_batch
:
True
batch_size
:
4096
# Hyparams for training:
# the number of epoches for training
epoch
:
30
# the hyper parameters for Adam optimizer.
# This static learning_rate will be multiplied to the LearningRateScheduler
# derived learning rate the to get the final learning rate.
learning_rate
:
2.0
beta1
:
0.9
beta2
:
0.997
eps
:
1e-9
# the parameters for learning rate scheduling.
warmup_steps
:
8000
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps
:
0.1
# Hyparams for generation:
# the parameters for beam search.
beam_size
:
5
max_out_len
:
256
# the number of decoded sentences to output.
n_best
:
1
# Hyparams for model:
# These following five vocabularies related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# size of source word dictionary.
src_vocab_size
:
10000
# size of target word dictionay
trg_vocab_size
:
10000
# index for <bos> token
bos_idx
:
0
# index for <eos> token
eos_idx
:
1
# index for <unk> token
unk_idx
:
2
# max length of sequences deciding the size of position encoding table.
max_length
:
256
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model
:
512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid
:
2048
# the dimension that keys are projected to for dot-product attention.
d_key
:
64
# the dimension that values are projected to for dot-product attention.
d_value
:
64
# number of head used in multi-head attention.
n_head
:
8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer
:
6
# dropout rates of different modules.
prepostprocess_dropout
:
0.1
attention_dropout
:
0.1
relu_dropout
:
0.1
# to process before each sub-layer
preprocess_cmd
:
"
n"
# layer normalization
# to process after each sub-layer
postprocess_cmd
:
"
da"
# dropout + residual connection
# the flag indicating whether to share embedding and softmax weights.
# vocabularies in source and target should be same for weight sharing.
weight_sharing
:
True
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录